Skip to main content

somatize_runtime/
executor.rs

1//! Plan executor — walks [`ExecutionPlan`] trees and runs filter nodes.
2//!
3//! Handles sequential, parallel (scoped threads), cached, remote, loop,
4//! and branch execution. Uses [`GraphInfo`] for topology-aware input resolution.
5
6use crate::event_bus::EventBus;
7use crate::filter_library::FilterLibrary;
8use somatize_compiler::ExecutionPlan;
9use somatize_core::cache::CacheStore;
10use somatize_core::error::{Result, SomaError};
11use somatize_core::event::Event;
12use somatize_core::store::DataStore;
13use somatize_core::value::Value;
14use somatize_core::virtual_value::VirtualValue;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Instant;
18
/// Dependency topology consulted when resolving a node's input.
///
/// For every node ID it stores the IDs of the nodes feeding into it, so the
/// executor knows which entries of the context store to read.
#[derive(Debug, Clone, Default)]
pub struct GraphInfo {
    /// Maps a node ID to the IDs of its upstream (predecessor) nodes.
    predecessors: HashMap<String, Vec<String>>,
}
28
29impl GraphInfo {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Register predecessors for a node.
35    pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36        self.predecessors.insert(node_id.into(), preds);
37    }
38
39    /// Build GraphInfo from a somatize_core::graph::Graph.
40    pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41        let mut info = Self::new();
42        for node in &graph.nodes {
43            let preds: Vec<String> = graph
44                .predecessors(&node.id)
45                .into_iter()
46                .map(|s| s.to_string())
47                .collect();
48            info.set_predecessors(node.id.clone(), preds);
49        }
50        info
51    }
52
53    /// Build GraphInfo for a linear pipeline (each node depends on the previous).
54    pub fn for_linear(node_ids: &[&str]) -> Self {
55        let mut info = Self::new();
56        for (i, &id) in node_ids.iter().enumerate() {
57            let preds = if i > 0 {
58                vec![node_ids[i - 1].to_string()]
59            } else {
60                vec![]
61            };
62            info.set_predecessors(id, preds);
63        }
64        info
65    }
66
67    /// Get predecessors for a node.
68    pub fn predecessors(&self, node_id: &str) -> &[String] {
69        self.predecessors
70            .get(node_id)
71            .map(|v| v.as_slice())
72            .unwrap_or(&[])
73    }
74}
75
76/// Trait for executing plan nodes on remote workers.
77///
78/// When set on Context, `ExecutionPlan::Remote` nodes delegate to this
79/// instead of executing locally. The implementation sends the sub-plan
80/// to a worker and returns the result.
81pub trait RemoteExecutor: Send + Sync {
82    /// Execute a sub-plan remotely and return the output value.
83    fn execute_remote(
84        &self,
85        node_id: &str,
86        target: &somatize_core::filter::RemoteTarget,
87        input: Option<&Value>,
88    ) -> Result<Value>;
89}
90
91/// Execution context passed to filters during runtime.
92///
93/// Node outputs are stored as [`VirtualValue`]s — they may be materialized
94/// in memory, cached on disk, or deferred (not yet computed). The executor
95/// resolves them on demand when a downstream node needs the data.
96pub struct Context {
97    /// Node outputs as virtual values (may be lazy).
98    pub store: HashMap<String, VirtualValue>,
99    /// Event bus for emitting runtime events.
100    pub event_bus: Arc<EventBus>,
101    /// Current run ID.
102    pub run_id: String,
103    /// Track execution order.
104    pub execution_order: Vec<String>,
105    /// Graph topology for input resolution.
106    pub graph_info: GraphInfo,
107    /// Optional remote executor for distributed plans.
108    pub remote_executor: Option<Arc<dyn RemoteExecutor>>,
109    /// Optional data store for persisting intermediate results.
110    pub data_store: Option<Arc<dyn DataStore>>,
111    /// Minimum value size (bytes) to spill to DataStore instead of keeping in memory.
112    /// Default: 0 (disabled — all values stay in memory).
113    pub spill_threshold: usize,
114}
115
116impl Context {
117    pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
118        Self {
119            store: HashMap::new(),
120            event_bus,
121            run_id: run_id.into(),
122            execution_order: Vec::new(),
123            graph_info: GraphInfo::new(),
124            remote_executor: None,
125            data_store: None,
126            spill_threshold: 0,
127        }
128    }
129
130    pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
131        self.graph_info = info;
132        self
133    }
134
135    pub fn with_remote_executor(mut self, executor: Arc<dyn RemoteExecutor>) -> Self {
136        self.remote_executor = Some(executor);
137        self
138    }
139
140    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
141        self.data_store = Some(store);
142        self
143    }
144
145    /// Set spill threshold: values larger than this (in bytes) are offloaded
146    /// to the DataStore and replaced with a VirtualValue::Cached reference.
147    /// Requires a DataStore to be set via `with_data_store()`.
148    pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
149        self.spill_threshold = bytes;
150        self
151    }
152
153    /// If a DataStore and spill threshold are configured, check if the value
154    /// should be offloaded. Returns VirtualValue (materialized or cached ref).
155    fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
156        if self.spill_threshold > 0
157            && let Some(store) = &self.data_store
158        {
159            let size = value.size() * 8; // approximate bytes (f64 = 8 bytes)
160            if size >= self.spill_threshold {
161                let key = somatize_core::cache::CacheKey::from_parts(&[
162                    self.run_id.as_bytes(),
163                    node_id.as_bytes(),
164                ]);
165                let vv_for_schema = VirtualValue::materialized(value.clone());
166                let schema = vv_for_schema.schema().clone();
167                if let Ok(_data_ref) = store.put(&key, &value) {
168                    tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
169                    return VirtualValue::cached(key, schema);
170                }
171            }
172        }
173        VirtualValue::materialized(value)
174    }
175
176    /// Get the materialized Value for a node, if present and materialized.
177    pub fn get(&self, node_id: &str) -> Option<&Value> {
178        self.store.get(node_id).and_then(|vv| vv.as_value())
179    }
180
181    /// Get the raw VirtualValue for a node.
182    pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
183        self.store.get(node_id)
184    }
185
186    /// Store a materialized value for a node.
187    pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
188        let id = node_id.into();
189        self.execution_order.push(id.clone());
190        self.store.insert(id, VirtualValue::materialized(value));
191    }
192
193    /// Store a virtual value (which may be deferred or cached).
194    pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
195        let id = node_id.into();
196        self.execution_order.push(id.clone());
197        self.store.insert(id, vv);
198    }
199
200    fn snapshot(&self) -> Self {
201        Self {
202            store: self.store.clone(),
203            event_bus: self.event_bus.clone(),
204            run_id: self.run_id.clone(),
205            execution_order: self.execution_order.clone(),
206            graph_info: self.graph_info.clone(),
207            remote_executor: self.remote_executor.clone(),
208            data_store: self.data_store.clone(),
209            spill_threshold: self.spill_threshold,
210        }
211    }
212}
213
214/// Execute a compiled plan synchronously.
215/// For parallel branches, uses the async executor under the hood.
216/// Contract for executing a plan.
217pub trait Executable {
218    fn execute(
219        &self,
220        ctx: &mut Context,
221        filters: &FilterLibrary,
222        cache: &dyn CacheStore,
223    ) -> Result<()>;
224}
225
226impl Executable for ExecutionPlan {
227    fn execute(
228        &self,
229        ctx: &mut Context,
230        filters: &FilterLibrary,
231        cache: &dyn CacheStore,
232    ) -> Result<()> {
233        match self {
234            ExecutionPlan::Empty => Ok(()),
235
236            ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
237
238            ExecutionPlan::Cached { node_id, key } => {
239                let start = Instant::now();
240                let value = cache.get(key)?.ok_or_else(|| {
241                    SomaError::Cache(format!(
242                        "expected cached value for node `{node_id}` not found"
243                    ))
244                })?;
245                ctx.set(node_id.clone(), value);
246                ctx.event_bus.emit(Event::NodeCacheHit {
247                    run_id: ctx.run_id.clone(),
248                    node_id: node_id.clone(),
249                    key: key.clone(),
250                    tier: somatize_core::cache::CacheTier::Memory,
251                    load_time: start.elapsed(),
252                });
253                Ok(())
254            }
255
256            ExecutionPlan::Sequence(steps) => {
257                for step in steps {
258                    step.execute(ctx, filters, cache)?;
259                }
260                Ok(())
261            }
262
263            ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
264
265            ExecutionPlan::Loop {
266                node_id,
267                body,
268                max_iterations,
269            } => {
270                let max = max_iterations.unwrap_or(100);
271                for i in 0..max {
272                    body.execute(ctx, filters, cache)?;
273
274                    // Check termination: if the last executed node produced a Value
275                    // that indicates "done" (true, "done", "stop", or empty), break.
276                    let should_stop = ctx
277                        .execution_order
278                        .last()
279                        .and_then(|last_id| ctx.get(last_id))
280                        .map(|v| match v {
281                            Value::Json(j) => {
282                                j.as_bool() == Some(true)
283                                    || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
284                                    || j.get("done").and_then(|d| d.as_bool()) == Some(true)
285                            }
286                            Value::Empty => true,
287                            _ => false,
288                        })
289                        .unwrap_or(false);
290
291                    if should_stop {
292                        ctx.event_bus.emit(Event::NodeCompleted {
293                            run_id: ctx.run_id.clone(),
294                            node_id: node_id.clone(),
295                            duration: std::time::Duration::ZERO,
296                            output_summary: format!("Loop terminated at iteration {}", i + 1),
297                        });
298                        break;
299                    }
300                }
301                Ok(())
302            }
303
304            ExecutionPlan::Branch { node_id, arms } => {
305                // Execute the branch node first (it produces the condition value)
306                execute_node(node_id, ctx, filters, cache)?;
307
308                // Get the condition result
309                let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
310
311                // Match against arm labels
312                let selected_arm = match &condition {
313                    Value::Json(j) => {
314                        // Try matching by string value, bool, or "branch" field
315                        let selector = j
316                            .as_str()
317                            .map(String::from)
318                            .or_else(|| j.as_bool().map(|b| b.to_string()))
319                            .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
320                            .unwrap_or_else(|| "true".to_string());
321
322                        arms.iter()
323                            .find(|(label, _)| label == &selector)
324                            .or_else(|| {
325                                arms.iter()
326                                    .find(|(label, _)| label == "default" || label == "else")
327                            })
328                            .or_else(|| arms.first())
329                    }
330                    _ => arms.first(),
331                };
332
333                if let Some((label, plan)) = selected_arm {
334                    ctx.event_bus.emit(Event::NodeCompleted {
335                        run_id: ctx.run_id.clone(),
336                        node_id: node_id.clone(),
337                        duration: std::time::Duration::ZERO,
338                        output_summary: format!("Branch selected: {label}"),
339                    });
340                    plan.execute(ctx, filters, cache)?;
341                }
342                Ok(())
343            }
344
345            ExecutionPlan::Remote {
346                node_id,
347                target,
348                plan,
349            } => {
350                if let Some(remote) = &ctx.remote_executor {
351                    // Gather input from predecessors
352                    let input = ctx
353                        .graph_info
354                        .predecessors(node_id)
355                        .first()
356                        .and_then(|pred| ctx.get(pred));
357
358                    let result = remote.execute_remote(node_id, target, input)?;
359                    ctx.set(node_id.clone(), result);
360                    ctx.execution_order.push(node_id.clone());
361                    Ok(())
362                } else {
363                    // No remote executor — fall back to local execution
364                    plan.execute(ctx, filters, cache)
365                }
366            }
367
368            ExecutionPlan::Composite { node_ids } => {
369                // Sequential fallback — execute each node in order.
370                // A future Python-aware executor will pass tensors directly.
371                for nid in node_ids {
372                    execute_node(nid, ctx, filters, cache)?;
373                }
374                Ok(())
375            }
376
377            _ => {
378                tracing::warn!("Unhandled ExecutionPlan variant");
379                Ok(())
380            }
381        }
382    }
383}
384
385/// Execute a plan (convenience function, delegates to `Executable` trait).
386pub fn execute(
387    plan: &ExecutionPlan,
388    ctx: &mut Context,
389    filters: &FilterLibrary,
390    cache: &dyn CacheStore,
391) -> Result<()> {
392    plan.execute(ctx, filters, cache)
393}
394
395/// Execute a single filter node.
396fn execute_node(
397    node_id: &str,
398    ctx: &mut Context,
399    filters: &FilterLibrary,
400    _cache: &dyn CacheStore,
401) -> Result<()> {
402    let start = Instant::now();
403
404    let filter = filters
405        .get(node_id)
406        .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
407
408    ctx.event_bus.emit(Event::NodeStarted {
409        run_id: ctx.run_id.clone(),
410        node_id: node_id.to_string(),
411        kind: filter.meta().kind,
412    });
413
414    let input = resolve_input(node_id, ctx);
415    let state = filters.get_state(node_id).cloned().unwrap_or(Value::Empty);
416    let result = filter.forward(&input, &state);
417
418    match result {
419        Ok(output) => {
420            let duration = start.elapsed();
421            let summary = format!("{output}");
422            let vv = ctx.maybe_spill(node_id, output);
423            ctx.set_virtual(node_id, vv);
424            ctx.event_bus.emit(Event::NodeCompleted {
425                run_id: ctx.run_id.clone(),
426                node_id: node_id.to_string(),
427                duration,
428                output_summary: summary,
429            });
430            Ok(())
431        }
432        Err(e) => {
433            ctx.event_bus.emit(Event::NodeFailed {
434                run_id: ctx.run_id.clone(),
435                node_id: node_id.to_string(),
436                error: e.to_string(),
437            });
438            Err(e)
439        }
440    }
441}
442
443/// Execute parallel branches concurrently using std::thread::scope.
444///
445/// Each branch gets a snapshot of the context. After all branches complete,
446/// their new outputs are merged back into the main context.
447fn execute_parallel(
448    branches: &[ExecutionPlan],
449    ctx: &mut Context,
450    filters: &FilterLibrary,
451    cache: &dyn CacheStore,
452) -> Result<()> {
453    let snapshot_keys: Arc<std::collections::HashSet<String>> =
454        Arc::new(ctx.store.keys().cloned().collect());
455
456    // Use scoped threads for true parallelism without Send requirements
457    let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
458        let handles: Vec<_> = branches
459            .iter()
460            .map(|branch| {
461                let mut branch_ctx = ctx.snapshot();
462                let keys = snapshot_keys.clone();
463                s.spawn(move || {
464                    execute(branch, &mut branch_ctx, filters, cache)?;
465                    let new_entries: Vec<(String, VirtualValue)> = branch_ctx
466                        .store
467                        .into_iter()
468                        .filter(|(k, _)| !keys.contains(k))
469                        .collect();
470                    Ok(new_entries)
471                })
472            })
473            .collect();
474
475        handles.into_iter().map(|h| h.join().unwrap()).collect()
476    });
477
478    // Merge results and propagate first error
479    for result in results {
480        let entries = result?;
481        for (key, vv) in entries {
482            ctx.set_virtual(key, vv);
483        }
484    }
485
486    Ok(())
487}
488
489/// Resolve a VirtualValue to a concrete Value, loading from DataStore if needed.
490fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
491    match vv {
492        VirtualValue::Materialized { value, .. } => Some(value.clone()),
493        VirtualValue::Cached { key, .. } => {
494            // Try to load from DataStore
495            if let Some(store) = data_store {
496                let data_ref = somatize_core::store::DataRef::Cached {
497                    cache_key: key.clone(),
498                };
499                store.get(&data_ref).ok()
500            } else {
501                None
502            }
503        }
504        _ => None,
505    }
506}
507
508/// Resolve the input for a node from the context store using graph topology.
509/// If a predecessor was spilled to DataStore, loads it back.
510pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
511    let preds = ctx.graph_info.predecessors(node_id);
512
513    let resolve_node = |id: &str| -> Option<Value> {
514        ctx.store
515            .get(id)
516            .and_then(|vv| resolve_value(vv, &ctx.data_store))
517    };
518
519    match preds.len() {
520        0 => ctx
521            .execution_order
522            .last()
523            .and_then(|id| resolve_node(id))
524            .unwrap_or(Value::Empty),
525        1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
526        _ => {
527            let mut merged = serde_json::Map::new();
528            for pred_id in preds {
529                if let Some(val) = resolve_node(pred_id) {
530                    let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
531                    merged.insert(pred_id.clone(), json_val);
532                }
533            }
534            Value::Json(serde_json::Value::Object(merged))
535        }
536    }
537}
538
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Doubles every element of a tensor input; passes other values through.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Doubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Adds `amount` to every element of a tensor input; passes other values through.
    struct AdderFilter {
        amount: f64,
    }

    impl Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let added: Vec<f64> = values.iter().map(|v| v + self.amount).collect();
                    Ok(Value::tensor(added, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Adder".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Sleeping pass-through filter that tracks how many instances are inside
    /// `forward` at the same time, so tests can observe overlap directly
    /// instead of relying on wall-clock thresholds.
    struct SlowFilter {
        id: String,
        delay_ms: u64,
        /// Number of `forward` calls currently in progress (shared across instances).
        active: Arc<AtomicUsize>,
        /// Highest value `active` has reached.
        peak: Arc<AtomicUsize>,
    }

    impl Filter for SlowFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Slow", self.id.as_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            let running = self.active.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak.fetch_max(running, Ordering::SeqCst);
            std::thread::sleep(std::time::Duration::from_millis(self.delay_ms));
            self.active.fetch_sub(1, Ordering::SeqCst);
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: format!("Slow_{}", self.id),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    fn setup() -> (Arc<EventBus>, MemoryCache) {
        (Arc::new(EventBus::new(64)), MemoryCache::default())
    }

    #[test]
    fn execute_single_node() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        ctx.graph_info.set_predecessors("doubler", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("doubler", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Execute {
            node_id: "doubler".into(),
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("doubler").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn execute_sequence_with_graph_info() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));

        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
        ctx.graph_info = graph_info;

        let mut filters = FilterLibrary::new();
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[22.0, 24.0]);
    }

    #[test]
    fn execute_cached_node() {
        let (bus, cache) = setup();
        let key = CacheKey::hash_data(b"cached_result");
        let cached_value = Value::tensor(vec![99.0], vec![1]);
        cache.put(&key, &cached_value).unwrap();

        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let plan = ExecutionPlan::Cached {
            node_id: "cached_node".into(),
            key,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
    }

    #[test]
    fn execute_emits_events() {
        let bus = Arc::new(EventBus::new(64));
        let cache = MemoryCache::default();
        let mut rx = bus.subscribe();

        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info.set_predecessors("double", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        execute(
            &ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        )
        .unwrap();

        let e1 = rx.try_recv().unwrap();
        assert!(matches!(e1, Event::NodeStarted { .. }));
        let e2 = rx.try_recv().unwrap();
        assert!(matches!(e2, Event::NodeCompleted { .. }));
    }

    #[test]
    fn execute_missing_filter_errors() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let result = execute(
            &ExecutionPlan::Execute {
                node_id: "nonexistent".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        );
        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn execute_empty_plan() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();
        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
    }

    #[test]
    fn execute_parallel_branches_merge_outputs() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
        ctx.graph_info.set_predecessors("double", vec!["input".into()]);
        ctx.graph_info.set_predecessors("add", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
        assert_eq!(double_out, &[10.0]);
        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
        assert_eq!(add_out, &[105.0]);
    }

    #[test]
    fn parallel_branches_run_concurrently() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info.set_predecessors("slow_a", vec!["input".into()]);
        ctx.graph_info.set_predecessors("slow_b", vec!["input".into()]);

        let active = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let slow = |id: &str| SlowFilter {
            id: id.into(),
            delay_ms: 50,
            active: Arc::clone(&active),
            peak: Arc::clone(&peak),
        };

        let mut filters = FilterLibrary::new();
        filters.register("slow_a", Box::new(slow("a")));
        filters.register("slow_b", Box::new(slow("b")));

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "slow_a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "slow_b".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        // Sequential execution can never have both filters inside `forward` at
        // once, while parallel branches overlap during their sleeps. Checking
        // the overlap directly avoids a total-elapsed-time bound that is flaky
        // on loaded machines.
        assert_eq!(
            peak.load(Ordering::SeqCst),
            2,
            "parallel branches never overlapped; they appear to run sequentially"
        );

        assert!(ctx.get("slow_a").is_some());
        assert!(ctx.get("slow_b").is_some());
    }

    #[test]
    fn resolve_input_single_predecessor() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
        ctx.graph_info.set_predecessors("B", vec!["A".into()]);

        let input = resolve_input("B", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);
    }

    #[test]
    fn resolve_input_multiple_predecessors() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
        ctx.graph_info
            .set_predecessors("C", vec!["A".into(), "B".into()]);

        let input = resolve_input("C", &ctx);
        let json = input.as_json().unwrap();
        assert!(json.get("A").is_some());
        assert!(json.get("B").is_some());
    }

    #[test]
    fn resolve_input_no_predecessors_fallback() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));

        let input = resolve_input("root", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[7.0]);
    }

    #[test]
    fn graph_info_from_linear() {
        let info = GraphInfo::for_linear(&["a", "b", "c"]);
        assert!(info.predecessors("a").is_empty());
        assert_eq!(info.predecessors("b"), &["a"]);
        assert_eq!(info.predecessors("c"), &["b"]);
    }
}