Skip to main content

somatize_runtime/
executor.rs

1//! Plan executor — walks [`ExecutionPlan`] trees and runs filter nodes.
2//!
3//! Handles sequential, parallel (scoped threads), cached, remote, loop,
4//! and branch execution. Uses [`GraphInfo`] for topology-aware input resolution.
5
use crate::event_bus::EventBus;
use crate::filter_library::FilterLibrary;
use somatize_compiler::ExecutionPlan;
use somatize_core::cache::CacheStore;
use somatize_core::error::{Result, SomaError};
use somatize_core::event::Event;
use somatize_core::store::DataStore;
use somatize_core::value::Value;
use somatize_core::virtual_value::VirtualValue;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
18
19/// Graph topology information for input resolution.
20///
21/// Maps each node to its predecessor node IDs so the executor knows
22/// where to read inputs from in the context store.
23#[derive(Debug, Clone, Default)]
24pub struct GraphInfo {
25    /// node_id → list of predecessor node IDs
26    predecessors: HashMap<String, Vec<String>>,
27}
28
29impl GraphInfo {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Register predecessors for a node.
35    pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36        self.predecessors.insert(node_id.into(), preds);
37    }
38
39    /// Build GraphInfo from a somatize_core::graph::Graph.
40    pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41        let mut info = Self::new();
42        for node in &graph.nodes {
43            let preds: Vec<String> = graph
44                .predecessors(&node.id)
45                .into_iter()
46                .map(|s| s.to_string())
47                .collect();
48            info.set_predecessors(node.id.clone(), preds);
49        }
50        info
51    }
52
53    /// Build GraphInfo for a linear pipeline (each node depends on the previous).
54    pub fn for_linear(node_ids: &[&str]) -> Self {
55        let mut info = Self::new();
56        for (i, &id) in node_ids.iter().enumerate() {
57            let preds = if i > 0 {
58                vec![node_ids[i - 1].to_string()]
59            } else {
60                vec![]
61            };
62            info.set_predecessors(id, preds);
63        }
64        info
65    }
66
67    /// Get predecessors for a node.
68    pub fn predecessors(&self, node_id: &str) -> &[String] {
69        self.predecessors
70            .get(node_id)
71            .map(|v| v.as_slice())
72            .unwrap_or(&[])
73    }
74}
75
76/// Execution context passed to filters during runtime.
77///
78/// Node outputs are stored as [`VirtualValue`]s — they may be materialized
79/// in memory, cached on disk, or deferred (not yet computed). The executor
80/// resolves them on demand when a downstream node needs the data.
81pub struct Context {
82    /// Node outputs as virtual values (may be lazy).
83    pub store: HashMap<String, VirtualValue>,
84    /// Event bus for emitting runtime events.
85    pub event_bus: Arc<EventBus>,
86    /// Current run ID.
87    pub run_id: String,
88    /// Track execution order.
89    pub execution_order: Vec<String>,
90    /// Graph topology for input resolution.
91    pub graph_info: GraphInfo,
92    /// Optional transport for distributed plans.
93    pub transport: Option<Arc<dyn crate::runner::Transport>>,
94    /// Optional data store for persisting intermediate results.
95    pub data_store: Option<Arc<dyn DataStore>>,
96    /// Minimum value size (bytes) to spill to DataStore instead of keeping in memory.
97    /// Default: 0 (disabled — all values stay in memory).
98    pub spill_threshold: usize,
99}
100
101impl Context {
102    pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
103        Self {
104            store: HashMap::new(),
105            event_bus,
106            run_id: run_id.into(),
107            execution_order: Vec::new(),
108            graph_info: GraphInfo::new(),
109            transport: None,
110            data_store: None,
111            spill_threshold: 0,
112        }
113    }
114
115    pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
116        self.graph_info = info;
117        self
118    }
119
120    pub fn with_transport(mut self, transport: Arc<dyn crate::runner::Transport>) -> Self {
121        self.transport = Some(transport);
122        self
123    }
124
125    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
126        self.data_store = Some(store);
127        self
128    }
129
130    /// Set spill threshold: values larger than this (in bytes) are offloaded
131    /// to the DataStore and replaced with a VirtualValue::Cached reference.
132    /// Requires a DataStore to be set via `with_data_store()`.
133    pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
134        self.spill_threshold = bytes;
135        self
136    }
137
138    /// If a DataStore and spill threshold are configured, check if the value
139    /// should be offloaded. Returns VirtualValue (materialized or cached ref).
140    fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
141        if self.spill_threshold > 0
142            && let Some(store) = &self.data_store
143        {
144            let size = value.size() * 8; // approximate bytes (f64 = 8 bytes)
145            if size >= self.spill_threshold {
146                let key = somatize_core::cache::CacheKey::from_parts(&[
147                    self.run_id.as_bytes(),
148                    node_id.as_bytes(),
149                ]);
150                let vv_for_schema = VirtualValue::materialized(value.clone());
151                let schema = vv_for_schema.schema().clone();
152                if let Ok(_data_ref) = store.put(&key, &value) {
153                    tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
154                    return VirtualValue::cached(key, schema);
155                }
156            }
157        }
158        VirtualValue::materialized(value)
159    }
160
161    /// Get the materialized Value for a node, if present and materialized.
162    pub fn get(&self, node_id: &str) -> Option<&Value> {
163        self.store.get(node_id).and_then(|vv| vv.as_value())
164    }
165
166    /// Get the raw VirtualValue for a node.
167    pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
168        self.store.get(node_id)
169    }
170
171    /// Store a materialized value for a node.
172    pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
173        let id = node_id.into();
174        self.execution_order.push(id.clone());
175        self.store.insert(id, VirtualValue::materialized(value));
176    }
177
178    /// Store a virtual value (which may be deferred or cached).
179    pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
180        let id = node_id.into();
181        self.execution_order.push(id.clone());
182        self.store.insert(id, vv);
183    }
184
185    fn snapshot(&self) -> Self {
186        Self {
187            store: self.store.clone(),
188            event_bus: self.event_bus.clone(),
189            run_id: self.run_id.clone(),
190            execution_order: self.execution_order.clone(),
191            graph_info: self.graph_info.clone(),
192            transport: self.transport.clone(),
193            data_store: self.data_store.clone(),
194            spill_threshold: self.spill_threshold,
195        }
196    }
197}
198
199/// Execute a compiled plan synchronously.
200/// For parallel branches, uses the async executor under the hood.
201/// Contract for executing a plan.
202pub trait Executable {
203    fn execute(
204        &self,
205        ctx: &mut Context,
206        filters: &FilterLibrary,
207        cache: &dyn CacheStore,
208    ) -> Result<()>;
209}
210
211impl Executable for ExecutionPlan {
212    fn execute(
213        &self,
214        ctx: &mut Context,
215        filters: &FilterLibrary,
216        cache: &dyn CacheStore,
217    ) -> Result<()> {
218        match self {
219            ExecutionPlan::Empty => Ok(()),
220
221            ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
222
223            ExecutionPlan::Cached { node_id, key } => {
224                let start = Instant::now();
225                let value = cache.get(key)?.ok_or_else(|| {
226                    SomaError::Cache(format!(
227                        "expected cached value for node `{node_id}` not found"
228                    ))
229                })?;
230                ctx.set(node_id.clone(), value);
231                ctx.event_bus.emit(Event::NodeCacheHit {
232                    run_id: ctx.run_id.clone(),
233                    node_id: node_id.clone(),
234                    key: key.clone(),
235                    tier: somatize_core::cache::CacheTier::Memory,
236                    load_time: start.elapsed(),
237                });
238                Ok(())
239            }
240
241            ExecutionPlan::Sequence(steps) => {
242                for step in steps {
243                    step.execute(ctx, filters, cache)?;
244                }
245                Ok(())
246            }
247
248            ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
249
250            ExecutionPlan::Loop {
251                node_id,
252                body,
253                max_iterations,
254            } => {
255                let max = max_iterations.unwrap_or(100);
256                for i in 0..max {
257                    body.execute(ctx, filters, cache)?;
258
259                    // Check termination: if the last executed node produced a Value
260                    // that indicates "done" (true, "done", "stop", or empty), break.
261                    let should_stop = ctx
262                        .execution_order
263                        .last()
264                        .and_then(|last_id| ctx.get(last_id))
265                        .map(|v| match v {
266                            Value::Json(j) => {
267                                j.as_bool() == Some(true)
268                                    || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
269                                    || j.get("done").and_then(|d| d.as_bool()) == Some(true)
270                            }
271                            Value::Empty => true,
272                            _ => false,
273                        })
274                        .unwrap_or(false);
275
276                    if should_stop {
277                        ctx.event_bus.emit(Event::NodeCompleted {
278                            run_id: ctx.run_id.clone(),
279                            node_id: node_id.clone(),
280                            duration: std::time::Duration::ZERO,
281                            output_summary: format!("Loop terminated at iteration {}", i + 1),
282                        });
283                        break;
284                    }
285                }
286                Ok(())
287            }
288
289            ExecutionPlan::Branch { node_id, arms } => {
290                // Execute the branch node first (it produces the condition value)
291                execute_node(node_id, ctx, filters, cache)?;
292
293                // Get the condition result
294                let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
295
296                // Match against arm labels
297                let selected_arm = match &condition {
298                    Value::Json(j) => {
299                        // Try matching by string value, bool, or "branch" field
300                        let selector = j
301                            .as_str()
302                            .map(String::from)
303                            .or_else(|| j.as_bool().map(|b| b.to_string()))
304                            .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
305                            .unwrap_or_else(|| "true".to_string());
306
307                        arms.iter()
308                            .find(|(label, _)| label == &selector)
309                            .or_else(|| {
310                                arms.iter()
311                                    .find(|(label, _)| label == "default" || label == "else")
312                            })
313                            .or_else(|| arms.first())
314                    }
315                    _ => arms.first(),
316                };
317
318                if let Some((label, plan)) = selected_arm {
319                    ctx.event_bus.emit(Event::NodeCompleted {
320                        run_id: ctx.run_id.clone(),
321                        node_id: node_id.clone(),
322                        duration: std::time::Duration::ZERO,
323                        output_summary: format!("Branch selected: {label}"),
324                    });
325                    plan.execute(ctx, filters, cache)?;
326                }
327                Ok(())
328            }
329
330            ExecutionPlan::Remote {
331                node_id,
332                target: _,
333                plan,
334            } => {
335                if let Some(transport) = &ctx.transport {
336                    // Gather input from predecessors
337                    let input = ctx
338                        .graph_info
339                        .predecessors(node_id)
340                        .first()
341                        .and_then(|pred| ctx.get(pred));
342
343                    let result = transport.execute_node(node_id, input)?;
344                    ctx.set(node_id.clone(), result);
345                    Ok(())
346                } else {
347                    // No transport — fall back to local execution
348                    plan.execute(ctx, filters, cache)
349                }
350            }
351
352            ExecutionPlan::Composite { node_ids } => {
353                // Sequential fallback — execute each node in order.
354                // A future Python-aware executor will pass tensors directly.
355                for nid in node_ids {
356                    execute_node(nid, ctx, filters, cache)?;
357                }
358                Ok(())
359            }
360
361            _ => {
362                tracing::warn!("Unhandled ExecutionPlan variant");
363                Ok(())
364            }
365        }
366    }
367}
368
369/// Execute a plan (convenience function, delegates to `Executable` trait).
370pub fn execute(
371    plan: &ExecutionPlan,
372    ctx: &mut Context,
373    filters: &FilterLibrary,
374    cache: &dyn CacheStore,
375) -> Result<()> {
376    plan.execute(ctx, filters, cache)
377}
378
379/// Execute a single filter node.
380fn execute_node(
381    node_id: &str,
382    ctx: &mut Context,
383    filters: &FilterLibrary,
384    _cache: &dyn CacheStore,
385) -> Result<()> {
386    let start = Instant::now();
387
388    let filter = filters
389        .get(node_id)
390        .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
391
392    ctx.event_bus.emit(Event::NodeStarted {
393        run_id: ctx.run_id.clone(),
394        node_id: node_id.to_string(),
395        kind: filter.meta().kind,
396    });
397
398    let _span = tracing::info_span!("execute_node", %node_id).entered();
399
400    let input = resolve_input(node_id, ctx);
401    let state = filters.get_state(node_id).cloned().unwrap_or(Value::Empty);
402
403    // catch_unwind: a panic in a user filter must not crash the process
404    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
405        filter.forward(&input, &state)
406    }));
407
408    let result = match result {
409        Ok(inner) => inner,
410        Err(panic) => {
411            let msg = panic
412                .downcast_ref::<String>()
413                .map(|s| s.as_str())
414                .or_else(|| panic.downcast_ref::<&str>().copied())
415                .unwrap_or("unknown panic");
416            tracing::error!(node_id, "filter panicked: {msg}");
417            Err(SomaError::Execution {
418                node_id: node_id.to_string(),
419                message: format!("filter panicked: {msg}"),
420            })
421        }
422    };
423
424    match result {
425        Ok(output) => {
426            let duration = start.elapsed();
427            let summary = format!("{output}");
428            let vv = ctx.maybe_spill(node_id, output);
429            ctx.set_virtual(node_id, vv);
430            ctx.event_bus.emit(Event::NodeCompleted {
431                run_id: ctx.run_id.clone(),
432                node_id: node_id.to_string(),
433                duration,
434                output_summary: summary,
435            });
436            Ok(())
437        }
438        Err(e) => {
439            tracing::error!(node_id, error = %e, "node execution failed");
440            ctx.event_bus.emit(Event::NodeFailed {
441                run_id: ctx.run_id.clone(),
442                node_id: node_id.to_string(),
443                error: e.to_string(),
444            });
445            Err(e)
446        }
447    }
448}
449
450/// Execute parallel branches concurrently using std::thread::scope.
451///
452/// Each branch gets a snapshot of the context. After all branches complete,
453/// their new outputs are merged back into the main context.
454fn execute_parallel(
455    branches: &[ExecutionPlan],
456    ctx: &mut Context,
457    filters: &FilterLibrary,
458    cache: &dyn CacheStore,
459) -> Result<()> {
460    let snapshot_keys: Arc<std::collections::HashSet<String>> =
461        Arc::new(ctx.store.keys().cloned().collect());
462
463    // Use scoped threads for true parallelism without Send requirements
464    let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
465        let handles: Vec<_> = branches
466            .iter()
467            .map(|branch| {
468                let mut branch_ctx = ctx.snapshot();
469                let keys = snapshot_keys.clone();
470                s.spawn(move || {
471                    execute(branch, &mut branch_ctx, filters, cache)?;
472                    let new_entries: Vec<(String, VirtualValue)> = branch_ctx
473                        .store
474                        .into_iter()
475                        .filter(|(k, _)| !keys.contains(k))
476                        .collect();
477                    Ok(new_entries)
478                })
479            })
480            .collect();
481
482        handles.into_iter().map(|h| h.join().unwrap()).collect()
483    });
484
485    // Merge results and propagate first error
486    for result in results {
487        let entries = result?;
488        for (key, vv) in entries {
489            ctx.set_virtual(key, vv);
490        }
491    }
492
493    Ok(())
494}
495
496/// Resolve a VirtualValue to a concrete Value, loading from DataStore if needed.
497fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
498    match vv {
499        VirtualValue::Materialized { value, .. } => Some(value.clone()),
500        VirtualValue::Cached { key, .. } => {
501            // Try to load from DataStore
502            if let Some(store) = data_store {
503                let data_ref = somatize_core::store::DataRef::Cached {
504                    cache_key: key.clone(),
505                };
506                store.get(&data_ref).ok()
507            } else {
508                None
509            }
510        }
511        _ => None,
512    }
513}
514
515/// Resolve the input for a node from the context store using graph topology.
516/// If a predecessor was spilled to DataStore, loads it back.
517pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
518    let preds = ctx.graph_info.predecessors(node_id);
519
520    let resolve_node = |id: &str| -> Option<Value> {
521        ctx.store
522            .get(id)
523            .and_then(|vv| resolve_value(vv, &ctx.data_store))
524    };
525
526    match preds.len() {
527        0 => ctx
528            .execution_order
529            .last()
530            .and_then(|id| resolve_node(id))
531            .unwrap_or(Value::Empty),
532        1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
533        _ => {
534            let mut merged = serde_json::Map::new();
535            for pred_id in preds {
536                if let Some(val) = resolve_node(pred_id) {
537                    let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
538                    merged.insert(pred_id.clone(), json_val);
539                }
540            }
541            Value::Json(serde_json::Value::Object(merged))
542        }
543    }
544}
545
#[cfg(test)]
mod tests {
    //! Unit tests for plan execution, parallel merging and input resolution.

    use super::*;
    use crate::cache::MemoryCache;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};

    /// Test filter: multiplies every tensor element by 2; passes other values through.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Doubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Test filter: adds `amount` to every tensor element; passes other values through.
    struct AdderFilter {
        amount: f64,
    }

    impl Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let added: Vec<f64> = values.iter().map(|v| v + self.amount).collect();
                    Ok(Value::tensor(added, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Adder".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Slow filter that sleeps to verify parallelism.
    /// Sleeps `delay_ms` in `forward`, then returns its input unchanged.
    struct SlowFilter {
        id: String,
        delay_ms: u64,
    }

    impl Filter for SlowFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Slow", self.id.as_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            std::thread::sleep(std::time::Duration::from_millis(self.delay_ms));
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: format!("Slow_{}", self.id),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Shared fixture: an event bus (constructed with 64, presumably its
    /// capacity) and an empty in-memory cache.
    fn setup() -> (Arc<EventBus>, MemoryCache) {
        (Arc::new(EventBus::new(64)), MemoryCache::default())
    }

    /// A single Execute step reads its predecessor's output and stores its own.
    #[test]
    fn execute_single_node() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        ctx.graph_info
            .set_predecessors("doubler", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("doubler", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Execute {
            node_id: "doubler".into(),
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("doubler").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    /// A Sequence chains outputs through a linear topology: (x + 10) * 2.
    #[test]
    fn execute_sequence_with_graph_info() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));

        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
        ctx.graph_info = graph_info;

        let mut filters = FilterLibrary::new();
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[22.0, 24.0]);
    }

    /// A Cached step loads the value from the cache without any registered filter.
    #[test]
    fn execute_cached_node() {
        let (bus, cache) = setup();
        let key = CacheKey::hash_data(b"cached_result");
        let cached_value = Value::tensor(vec![99.0], vec![1]);
        cache.put(&key, &cached_value).unwrap();

        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let plan = ExecutionPlan::Cached {
            node_id: "cached_node".into(),
            key,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
    }

    /// A successful node emits NodeStarted followed by NodeCompleted.
    #[test]
    fn execute_emits_events() {
        let bus = Arc::new(EventBus::new(64));
        let cache = MemoryCache::default();
        // Subscribe before executing so no events are missed.
        let mut rx = bus.subscribe();

        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        execute(
            &ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        )
        .unwrap();

        let e1 = rx.try_recv().unwrap();
        assert!(matches!(e1, Event::NodeStarted { .. }));
        let e2 = rx.try_recv().unwrap();
        assert!(matches!(e2, Event::NodeCompleted { .. }));
    }

    /// Executing a node with no registered filter yields NodeNotFound.
    #[test]
    fn execute_missing_filter_errors() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let result = execute(
            &ExecutionPlan::Execute {
                node_id: "nonexistent".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        );
        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
    }

    /// The Empty plan is a no-op that succeeds.
    #[test]
    fn execute_empty_plan() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();
        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
    }

    /// Outputs produced in separate parallel branches are merged into the main context.
    #[test]
    fn execute_parallel_branches_merge_outputs() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);
        ctx.graph_info.set_predecessors("add", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
        assert_eq!(double_out, &[10.0]);
        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
        assert_eq!(add_out, &[105.0]);
    }

    /// Two 50 ms branches should overlap in time.
    /// NOTE(review): wall-clock threshold; may be flaky on heavily loaded CI hosts.
    #[test]
    fn parallel_branches_run_concurrently() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("slow_a", vec!["input".into()]);
        ctx.graph_info
            .set_predecessors("slow_b", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register(
            "slow_a",
            Box::new(SlowFilter {
                id: "a".into(),
                delay_ms: 50,
            }),
        );
        filters.register(
            "slow_b",
            Box::new(SlowFilter {
                id: "b".into(),
                delay_ms: 50,
            }),
        );

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "slow_a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "slow_b".into(),
            },
        ]);

        let start = Instant::now();
        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        let elapsed = start.elapsed();

        // If truly parallel: ~50ms. If sequential: ~100ms.
        // Use 90ms as threshold to account for overhead.
        assert!(
            elapsed.as_millis() < 90,
            "parallel branches took {}ms, expected <90ms (sequential would be ~100ms)",
            elapsed.as_millis()
        );

        assert!(ctx.get("slow_a").is_some());
        assert!(ctx.get("slow_b").is_some());
    }

    /// With exactly one predecessor, that predecessor's output is the input.
    #[test]
    fn resolve_input_single_predecessor() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
        ctx.graph_info.set_predecessors("B", vec!["A".into()]);

        let input = resolve_input("B", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);
    }

    /// With several predecessors, the input is a JSON object keyed by predecessor ID.
    #[test]
    fn resolve_input_multiple_predecessors() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
        ctx.graph_info
            .set_predecessors("C", vec!["A".into(), "B".into()]);

        let input = resolve_input("C", &ctx);
        let json = input.as_json().unwrap();
        assert!(json.get("A").is_some());
        assert!(json.get("B").is_some());
    }

    /// A node with no registered predecessors falls back to the most recent output.
    #[test]
    fn resolve_input_no_predecessors_fallback() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));

        let input = resolve_input("root", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[7.0]);
    }

    /// `for_linear` links each node to the one immediately before it.
    #[test]
    fn graph_info_from_linear() {
        let info = GraphInfo::for_linear(&["a", "b", "c"]);
        assert!(info.predecessors("a").is_empty());
        assert_eq!(info.predecessors("b"), &["a"]);
        assert_eq!(info.predecessors("c"), &["b"]);
    }
}