Skip to main content

somatize_runtime/
executor.rs

1//! Plan executor — walks [`ExecutionPlan`] trees and runs filter nodes.
2//!
3//! Handles sequential, parallel (scoped threads), cached, remote, loop,
4//! and branch execution. Uses [`GraphInfo`] for topology-aware input resolution.
5
6use crate::event_bus::EventBus;
7use crate::filter_library::FilterLibrary;
8use somatize_compiler::ExecutionPlan;
9use somatize_core::cache::CacheStore;
10use somatize_core::error::{Result, SomaError};
11use somatize_core::event::Event;
12use somatize_core::store::DataStore;
13use somatize_core::value::Value;
14use somatize_core::virtual_value::VirtualValue;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Instant;
18
19/// Graph topology information for input resolution.
20///
21/// Maps each node to its predecessor node IDs so the executor knows
22/// where to read inputs from in the context store.
23#[derive(Debug, Clone, Default)]
24pub struct GraphInfo {
25    /// node_id → list of predecessor node IDs
26    predecessors: HashMap<String, Vec<String>>,
27}
28
29impl GraphInfo {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Register predecessors for a node.
35    pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36        self.predecessors.insert(node_id.into(), preds);
37    }
38
39    /// Build GraphInfo from a somatize_core::graph::Graph.
40    pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41        let mut info = Self::new();
42        for node in &graph.nodes {
43            let preds: Vec<String> = graph
44                .predecessors(&node.id)
45                .into_iter()
46                .map(|s| s.to_string())
47                .collect();
48            info.set_predecessors(node.id.clone(), preds);
49        }
50        info
51    }
52
53    /// Build GraphInfo for a linear pipeline (each node depends on the previous).
54    pub fn for_linear(node_ids: &[&str]) -> Self {
55        let mut info = Self::new();
56        for (i, &id) in node_ids.iter().enumerate() {
57            let preds = if i > 0 {
58                vec![node_ids[i - 1].to_string()]
59            } else {
60                vec![]
61            };
62            info.set_predecessors(id, preds);
63        }
64        info
65    }
66
67    /// Get predecessors for a node.
68    pub fn predecessors(&self, node_id: &str) -> &[String] {
69        self.predecessors
70            .get(node_id)
71            .map(|v| v.as_slice())
72            .unwrap_or(&[])
73    }
74}
75
76/// Execution context passed to filters during runtime.
77///
78/// Node outputs are stored as [`VirtualValue`]s — they may be materialized
79/// in memory, cached on disk, or deferred (not yet computed). The executor
80/// resolves them on demand when a downstream node needs the data.
81pub struct Context {
82    /// Node outputs as virtual values (may be lazy).
83    pub store: HashMap<String, VirtualValue>,
84    /// Event bus for emitting runtime events.
85    pub event_bus: Arc<EventBus>,
86    /// Current run ID.
87    pub run_id: String,
88    /// Track execution order.
89    pub execution_order: Vec<String>,
90    /// Graph topology for input resolution.
91    pub graph_info: GraphInfo,
92    /// Optional transport for distributed plans.
93    pub transport: Option<Arc<dyn crate::runner::Transport>>,
94    /// Optional data store for persisting intermediate results.
95    pub data_store: Option<Arc<dyn DataStore>>,
96    /// Minimum value size (bytes) to spill to DataStore instead of keeping in memory.
97    /// Default: 0 (disabled — all values stay in memory).
98    pub spill_threshold: usize,
99}
100
101impl Context {
102    pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
103        Self {
104            store: HashMap::new(),
105            event_bus,
106            run_id: run_id.into(),
107            execution_order: Vec::new(),
108            graph_info: GraphInfo::new(),
109            transport: None,
110            data_store: None,
111            spill_threshold: 0,
112        }
113    }
114
115    pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
116        self.graph_info = info;
117        self
118    }
119
120    pub fn with_transport(mut self, transport: Arc<dyn crate::runner::Transport>) -> Self {
121        self.transport = Some(transport);
122        self
123    }
124
125    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
126        self.data_store = Some(store);
127        self
128    }
129
130    /// Set spill threshold: values larger than this (in bytes) are offloaded
131    /// to the DataStore and replaced with a VirtualValue::Cached reference.
132    /// Requires a DataStore to be set via `with_data_store()`.
133    pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
134        self.spill_threshold = bytes;
135        self
136    }
137
138    /// If a DataStore and spill threshold are configured, check if the value
139    /// should be offloaded. Returns VirtualValue (materialized or cached ref).
140    fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
141        if self.spill_threshold > 0
142            && let Some(store) = &self.data_store
143        {
144            let size = value.size() * 8; // approximate bytes (f64 = 8 bytes)
145            if size >= self.spill_threshold {
146                let key = somatize_core::cache::CacheKey::from_parts(&[
147                    self.run_id.as_bytes(),
148                    node_id.as_bytes(),
149                ]);
150                let vv_for_schema = VirtualValue::materialized(value.clone());
151                let schema = vv_for_schema.schema().clone();
152                if let Ok(_data_ref) = store.put(&key, &value) {
153                    tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
154                    return VirtualValue::cached(key, schema);
155                }
156            }
157        }
158        VirtualValue::materialized(value)
159    }
160
161    /// Get the materialized Value for a node, if present and materialized.
162    pub fn get(&self, node_id: &str) -> Option<&Value> {
163        self.store.get(node_id).and_then(|vv| vv.as_value())
164    }
165
166    /// Get the raw VirtualValue for a node.
167    pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
168        self.store.get(node_id)
169    }
170
171    /// Store a materialized value for a node.
172    pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
173        let id = node_id.into();
174        self.execution_order.push(id.clone());
175        self.store.insert(id, VirtualValue::materialized(value));
176    }
177
178    /// Store a virtual value (which may be deferred or cached).
179    pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
180        let id = node_id.into();
181        self.execution_order.push(id.clone());
182        self.store.insert(id, vv);
183    }
184
185    fn snapshot(&self) -> Self {
186        Self {
187            store: self.store.clone(),
188            event_bus: self.event_bus.clone(),
189            run_id: self.run_id.clone(),
190            execution_order: self.execution_order.clone(),
191            graph_info: self.graph_info.clone(),
192            transport: self.transport.clone(),
193            data_store: self.data_store.clone(),
194            spill_threshold: self.spill_threshold,
195        }
196    }
197}
198
199/// Execute a compiled plan synchronously.
200/// For parallel branches, uses the async executor under the hood.
201/// Contract for executing a plan.
202pub trait Executable {
203    fn execute(
204        &self,
205        ctx: &mut Context,
206        filters: &FilterLibrary,
207        cache: &dyn CacheStore,
208    ) -> Result<()>;
209}
210
211impl Executable for ExecutionPlan {
212    fn execute(
213        &self,
214        ctx: &mut Context,
215        filters: &FilterLibrary,
216        cache: &dyn CacheStore,
217    ) -> Result<()> {
218        match self {
219            ExecutionPlan::Empty => Ok(()),
220
221            ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
222
223            ExecutionPlan::Cached { node_id, key } => {
224                let start = Instant::now();
225                let value = cache.get(key)?.ok_or_else(|| {
226                    SomaError::Cache(format!(
227                        "expected cached value for node `{node_id}` not found"
228                    ))
229                })?;
230                ctx.set(node_id.clone(), value);
231                ctx.event_bus.emit(Event::NodeCacheHit {
232                    run_id: ctx.run_id.clone(),
233                    node_id: node_id.clone(),
234                    key: key.clone(),
235                    tier: somatize_core::cache::CacheTier::Memory,
236                    load_time: start.elapsed(),
237                });
238                Ok(())
239            }
240
241            ExecutionPlan::Sequence(steps) => {
242                for step in steps {
243                    step.execute(ctx, filters, cache)?;
244                }
245                Ok(())
246            }
247
248            ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
249
250            ExecutionPlan::Loop {
251                node_id,
252                body,
253                max_iterations,
254            } => {
255                let max = max_iterations.unwrap_or(100);
256                for i in 0..max {
257                    body.execute(ctx, filters, cache)?;
258
259                    // Check termination: if the last executed node produced a Value
260                    // that indicates "done" (true, "done", "stop", or empty), break.
261                    let should_stop = ctx
262                        .execution_order
263                        .last()
264                        .and_then(|last_id| ctx.get(last_id))
265                        .map(|v| match v {
266                            Value::Json(j) => {
267                                j.as_bool() == Some(true)
268                                    || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
269                                    || j.get("done").and_then(|d| d.as_bool()) == Some(true)
270                            }
271                            Value::Empty => true,
272                            _ => false,
273                        })
274                        .unwrap_or(false);
275
276                    if should_stop {
277                        ctx.event_bus.emit(Event::NodeCompleted {
278                            run_id: ctx.run_id.clone(),
279                            node_id: node_id.clone(),
280                            duration: std::time::Duration::ZERO,
281                            output_summary: format!("Loop terminated at iteration {}", i + 1),
282                        });
283                        break;
284                    }
285                }
286                Ok(())
287            }
288
289            ExecutionPlan::Branch { node_id, arms } => {
290                // Execute the branch node first (it produces the condition value)
291                execute_node(node_id, ctx, filters, cache)?;
292
293                // Get the condition result
294                let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
295
296                // Match against arm labels
297                let selected_arm = match &condition {
298                    Value::Json(j) => {
299                        // Try matching by string value, bool, or "branch" field
300                        let selector = j
301                            .as_str()
302                            .map(String::from)
303                            .or_else(|| j.as_bool().map(|b| b.to_string()))
304                            .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
305                            .unwrap_or_else(|| "true".to_string());
306
307                        arms.iter()
308                            .find(|(label, _)| label == &selector)
309                            .or_else(|| {
310                                arms.iter()
311                                    .find(|(label, _)| label == "default" || label == "else")
312                            })
313                            .or_else(|| arms.first())
314                    }
315                    _ => arms.first(),
316                };
317
318                if let Some((label, plan)) = selected_arm {
319                    ctx.event_bus.emit(Event::NodeCompleted {
320                        run_id: ctx.run_id.clone(),
321                        node_id: node_id.clone(),
322                        duration: std::time::Duration::ZERO,
323                        output_summary: format!("Branch selected: {label}"),
324                    });
325                    plan.execute(ctx, filters, cache)?;
326                }
327                Ok(())
328            }
329
330            ExecutionPlan::Remote {
331                node_id,
332                target: _,
333                plan,
334            } => {
335                if let Some(transport) = &ctx.transport {
336                    // Gather input from predecessors
337                    let input = ctx
338                        .graph_info
339                        .predecessors(node_id)
340                        .first()
341                        .and_then(|pred| ctx.get(pred));
342
343                    let result = transport.execute_node(node_id, input)?;
344                    ctx.set(node_id.clone(), result);
345                    Ok(())
346                } else {
347                    // No transport — fall back to local execution
348                    plan.execute(ctx, filters, cache)
349                }
350            }
351
352            ExecutionPlan::Composite { node_ids } => {
353                // Sequential fallback — execute each node in order.
354                // A future Python-aware executor will pass tensors directly.
355                for nid in node_ids {
356                    execute_node(nid, ctx, filters, cache)?;
357                }
358                Ok(())
359            }
360
361            ExecutionPlan::Stream {
362                node_ids,
363                chunk_size,
364            } => execute_stream(node_ids, *chunk_size, ctx, filters, cache),
365
366            _ => {
367                tracing::warn!("Unhandled ExecutionPlan variant");
368                Ok(())
369            }
370        }
371    }
372}
373
374/// Execute a plan (convenience function, delegates to `Executable` trait).
375pub fn execute(
376    plan: &ExecutionPlan,
377    ctx: &mut Context,
378    filters: &FilterLibrary,
379    cache: &dyn CacheStore,
380) -> Result<()> {
381    plan.execute(ctx, filters, cache)
382}
383
384/// Execute a single filter node.
385fn execute_node(
386    node_id: &str,
387    ctx: &mut Context,
388    filters: &FilterLibrary,
389    _cache: &dyn CacheStore,
390) -> Result<()> {
391    let start = Instant::now();
392
393    let filter = filters
394        .get(node_id)
395        .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
396
397    ctx.event_bus.emit(Event::NodeStarted {
398        run_id: ctx.run_id.clone(),
399        node_id: node_id.to_string(),
400        kind: filter.meta().kind,
401    });
402
403    let _span = tracing::info_span!("execute_node", %node_id).entered();
404
405    let input = resolve_input(node_id, ctx);
406    // Borrow state via Arc — cloning the inner Value here would deep-copy
407    // potentially huge tensors (encoder outputs, model weights) on every
408    // forward call. Arc::clone is a cheap atomic increment.
409    let state = filters.get_state(node_id);
410    let state_ref: &Value = state.as_deref().unwrap_or(&Value::Empty);
411
412    // catch_unwind: a panic in a user filter must not crash the process
413    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
414        filter.forward(&input, state_ref)
415    }));
416
417    let result = match result {
418        Ok(inner) => inner,
419        Err(panic) => {
420            let msg = panic
421                .downcast_ref::<String>()
422                .map(|s| s.as_str())
423                .or_else(|| panic.downcast_ref::<&str>().copied())
424                .unwrap_or("unknown panic");
425            tracing::error!(node_id, "filter panicked: {msg}");
426            Err(SomaError::Execution {
427                node_id: node_id.to_string(),
428                message: format!("filter panicked: {msg}"),
429            })
430        }
431    };
432
433    match result {
434        Ok(output) => {
435            let duration = start.elapsed();
436            let summary = format!("{output}");
437            let vv = ctx.maybe_spill(node_id, output);
438            ctx.set_virtual(node_id, vv);
439            ctx.event_bus.emit(Event::NodeCompleted {
440                run_id: ctx.run_id.clone(),
441                node_id: node_id.to_string(),
442                duration,
443                output_summary: summary,
444            });
445            Ok(())
446        }
447        Err(e) => {
448            tracing::error!(node_id, error = %e, "node execution failed");
449            ctx.event_bus.emit(Event::NodeFailed {
450                run_id: ctx.run_id.clone(),
451                node_id: node_id.to_string(),
452                error: e.to_string(),
453            });
454            Err(e)
455        }
456    }
457}
458
459/// Execute parallel branches concurrently using std::thread::scope.
460///
461/// Each branch gets a snapshot of the context. After all branches complete,
462/// their new outputs are merged back into the main context.
463fn execute_parallel(
464    branches: &[ExecutionPlan],
465    ctx: &mut Context,
466    filters: &FilterLibrary,
467    cache: &dyn CacheStore,
468) -> Result<()> {
469    let snapshot_keys: Arc<std::collections::HashSet<String>> =
470        Arc::new(ctx.store.keys().cloned().collect());
471
472    // Use scoped threads for true parallelism without Send requirements
473    let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
474        let handles: Vec<_> = branches
475            .iter()
476            .map(|branch| {
477                let mut branch_ctx = ctx.snapshot();
478                let keys = snapshot_keys.clone();
479                s.spawn(move || {
480                    execute(branch, &mut branch_ctx, filters, cache)?;
481                    let new_entries: Vec<(String, VirtualValue)> = branch_ctx
482                        .store
483                        .into_iter()
484                        .filter(|(k, _)| !keys.contains(k))
485                        .collect();
486                    Ok(new_entries)
487                })
488            })
489            .collect();
490
491        handles.into_iter().map(|h| h.join().unwrap()).collect()
492    });
493
494    // Merge results and propagate first error
495    for result in results {
496        let entries = result?;
497        for (key, vv) in entries {
498            ctx.set_virtual(key, vv);
499        }
500    }
501
502    Ok(())
503}
504
505/// Resolve a VirtualValue to a concrete Value, loading from DataStore if needed.
506fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
507    match vv {
508        VirtualValue::Materialized { value, .. } => Some(value.clone()),
509        VirtualValue::Cached { key, .. } => {
510            // Try to load from DataStore
511            if let Some(store) = data_store {
512                let data_ref = somatize_core::store::DataRef::Cached {
513                    cache_key: key.clone(),
514                };
515                store.get(&data_ref).ok()
516            } else {
517                None
518            }
519        }
520        _ => None,
521    }
522}
523
524/// Resolve the input for a node from the context store using graph topology.
525/// If a predecessor was spilled to DataStore, loads it back.
526pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
527    let preds = ctx.graph_info.predecessors(node_id);
528
529    let resolve_node = |id: &str| -> Option<Value> {
530        ctx.store
531            .get(id)
532            .and_then(|vv| resolve_value(vv, &ctx.data_store))
533    };
534
535    match preds.len() {
536        0 => ctx
537            .execution_order
538            .last()
539            .and_then(|id| resolve_node(id))
540            .unwrap_or(Value::Empty),
541        1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
542        _ => {
543            let mut merged = serde_json::Map::new();
544            for pred_id in preds {
545                if let Some(val) = resolve_node(pred_id) {
546                    let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
547                    merged.insert(pred_id.clone(), json_val);
548                }
549            }
550            Value::Json(serde_json::Value::Object(merged))
551        }
552    }
553}
554
555/// Execute a stream plan: chunk input, process through StreamExecutor, concatenate.
556fn execute_stream(
557    node_ids: &[String],
558    chunk_size: usize,
559    ctx: &mut Context,
560    filters: &FilterLibrary,
561    cache: &dyn CacheStore,
562) -> Result<()> {
563    use crate::executors::{FittedFilter, StreamExecutor, materialize_buffer};
564
565    let start = Instant::now();
566
567    // Build FittedFilter list from the library in plan order.
568    let fitted: Vec<FittedFilter> = node_ids
569        .iter()
570        .map(|nid| {
571            let filter = filters
572                .get(nid)
573                .ok_or_else(|| SomaError::NodeNotFound(nid.clone()))?;
574            let state = filters
575                .get_state(nid)
576                .map(|arc| (*arc).clone())
577                .unwrap_or(Value::Empty);
578            Ok(FittedFilter {
579                name: nid.clone(),
580                filter,
581                state,
582            })
583        })
584        .collect::<Result<_>>()?;
585
586    // Resolve input from the first node's predecessors.
587    let first_id = node_ids
588        .first()
589        .ok_or_else(|| SomaError::Other("stream plan has no nodes".into()))?;
590    let input = resolve_input(first_id, ctx);
591
592    // Chunk the input along the first tensor dimension.
593    let chunks = chunk_value(&input, chunk_size);
594
595    // Process chunks through the stream executor.
596    let mut executor = StreamExecutor::new(fitted);
597    if let Some(c) = cache_as_arc(cache) {
598        executor = executor.with_cache(c);
599    }
600
601    let last_id = node_ids.last().unwrap();
602
603    let mut outputs: Vec<Value> = Vec::new();
604    for (i, chunk) in chunks.into_iter().enumerate() {
605        ctx.event_bus.emit(Event::NodeStarted {
606            run_id: ctx.run_id.clone(),
607            node_id: format!("{last_id}#chunk_{i}"),
608            kind: somatize_core::filter::FilterKind::Stateless,
609        });
610        if let Some(output) = executor.process_chunk(chunk)? {
611            outputs.push(output);
612        }
613    }
614
615    // Flush barrier filters.
616    if let Some(flushed) = executor.flush()? {
617        outputs.push(flushed);
618    }
619
620    // Concatenate results into a single value.
621    let result = if outputs.len() == 1 {
622        outputs.into_iter().next().unwrap()
623    } else if outputs.is_empty() {
624        Value::Empty
625    } else {
626        materialize_buffer(&outputs)?
627    };
628
629    let duration = start.elapsed();
630    ctx.set(last_id.clone(), result);
631    ctx.event_bus.emit(Event::NodeCompleted {
632        run_id: ctx.run_id.clone(),
633        node_id: last_id.clone(),
634        duration,
635        output_summary: format!(
636            "stream: {} chunks through {} filters",
637            executor.chunks_processed(),
638            node_ids.len()
639        ),
640    });
641
642    Ok(())
643}
644
645/// Split a Value::Tensor along the first dimension into chunks.
646fn chunk_value(x: &Value, chunk_size: usize) -> Vec<Value> {
647    match x {
648        Value::Tensor { values, shape } if !values.is_empty() && chunk_size > 0 => {
649            let row_size = if shape.len() > 1 {
650                shape[1..].iter().product()
651            } else {
652                1
653            };
654            let n_rows = shape[0];
655            let mut chunks = Vec::new();
656            for start in (0..n_rows).step_by(chunk_size) {
657                let end = (start + chunk_size).min(n_rows);
658                let flat_start = start * row_size;
659                let flat_end = end * row_size;
660                let chunk_vals = values[flat_start..flat_end].to_vec();
661                let mut chunk_shape = shape.clone();
662                chunk_shape[0] = end - start;
663                chunks.push(Value::tensor(chunk_vals, chunk_shape));
664            }
665            chunks
666        }
667        _ => vec![x.clone()],
668    }
669}
670
671/// Try to wrap the cache reference as an Arc for StreamExecutor.
672/// This is safe because the cache outlives the executor within execute().
673fn cache_as_arc(_cache: &dyn CacheStore) -> Option<Arc<dyn CacheStore>> {
674    // StreamExecutor requires Arc, but we only have a reference.
675    // For local execution we skip the cache in streaming mode — the per-chunk
676    // cache is an optimization, not a correctness requirement.
677    // TODO: pass cache as Option<&dyn CacheStore> to StreamExecutor.
678    None
679}
680
681#[cfg(test)]
682mod tests {
683    use super::*;
684    use crate::cache::MemoryCache;
685    use somatize_core::cache::CacheKey;
686    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
687
688    struct DoublerFilter;
689
690    impl Filter for DoublerFilter {
691        fn config_hash(&self) -> CacheKey {
692            CacheKey::from_parts(&[b"Doubler"])
693        }
694        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
695            Ok(Value::Empty)
696        }
697        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
698            match x {
699                Value::Tensor { values, shape } => {
700                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
701                    Ok(Value::tensor(doubled, shape.clone()))
702                }
703                _ => Ok(x.clone()),
704            }
705        }
706        fn meta(&self) -> FilterMeta {
707            FilterMeta {
708                name: "Doubler".into(),
709                kind: FilterKind::Stateless,
710                cacheable: true,
711                differentiable: true,
712                stream_mode: StreamMode::FixedState,
713                distribution: somatize_core::filter::Distribution::Local,
714                input_schema: None,
715                output_schema: None,
716            }
717        }
718
719        fn as_any(&self) -> &dyn std::any::Any {
720            self
721        }
722    }
723
724    struct AdderFilter {
725        amount: f64,
726    }
727
728    impl Filter for AdderFilter {
729        fn config_hash(&self) -> CacheKey {
730            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
731        }
732        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
733            Ok(Value::Empty)
734        }
735        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
736            match x {
737                Value::Tensor { values, shape } => {
738                    let added: Vec<f64> = values.iter().map(|v| v + self.amount).collect();
739                    Ok(Value::tensor(added, shape.clone()))
740                }
741                _ => Ok(x.clone()),
742            }
743        }
744        fn meta(&self) -> FilterMeta {
745            FilterMeta {
746                name: "Adder".into(),
747                kind: FilterKind::Stateless,
748                cacheable: true,
749                differentiable: true,
750                stream_mode: StreamMode::FixedState,
751                distribution: somatize_core::filter::Distribution::Local,
752                input_schema: None,
753                output_schema: None,
754            }
755        }
756
757        fn as_any(&self) -> &dyn std::any::Any {
758            self
759        }
760    }
761
762    /// Slow filter that sleeps to verify parallelism.
763    struct SlowFilter {
764        id: String,
765        delay_ms: u64,
766    }
767
768    impl Filter for SlowFilter {
769        fn config_hash(&self) -> CacheKey {
770            CacheKey::from_parts(&[b"Slow", self.id.as_bytes()])
771        }
772        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
773            Ok(Value::Empty)
774        }
775        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
776            std::thread::sleep(std::time::Duration::from_millis(self.delay_ms));
777            Ok(x.clone())
778        }
779        fn meta(&self) -> FilterMeta {
780            FilterMeta {
781                name: format!("Slow_{}", self.id),
782                kind: FilterKind::Stateless,
783                cacheable: false,
784                differentiable: true,
785                stream_mode: StreamMode::FixedState,
786                distribution: somatize_core::filter::Distribution::Local,
787                input_schema: None,
788                output_schema: None,
789            }
790        }
791
792        fn as_any(&self) -> &dyn std::any::Any {
793            self
794        }
795    }
796
797    fn setup() -> (Arc<EventBus>, MemoryCache) {
798        (Arc::new(EventBus::new(64)), MemoryCache::default())
799    }
800
801    #[test]
802    fn execute_single_node() {
803        let (bus, cache) = setup();
804        let mut ctx = Context::new(bus, "run_1");
805        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
806        ctx.graph_info
807            .set_predecessors("doubler", vec!["input".into()]);
808
809        let mut filters = FilterLibrary::new();
810        filters.register("doubler", Box::new(DoublerFilter));
811
812        let plan = ExecutionPlan::Execute {
813            node_id: "doubler".into(),
814        };
815
816        execute(&plan, &mut ctx, &filters, &cache).unwrap();
817
818        let result = ctx.get("doubler").unwrap();
819        let (data, _) = result.as_tensor().unwrap();
820        assert_eq!(data, &[2.0, 4.0, 6.0]);
821    }
822
823    #[test]
824    fn execute_sequence_with_graph_info() {
825        let (bus, cache) = setup();
826        let mut ctx = Context::new(bus, "run_1");
827        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));
828
829        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
830        ctx.graph_info = graph_info;
831
832        let mut filters = FilterLibrary::new();
833        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
834        filters.register("double", Box::new(DoublerFilter));
835
836        let plan = ExecutionPlan::Sequence(vec![
837            ExecutionPlan::Execute {
838                node_id: "add".into(),
839            },
840            ExecutionPlan::Execute {
841                node_id: "double".into(),
842            },
843        ]);
844
845        execute(&plan, &mut ctx, &filters, &cache).unwrap();
846
847        let result = ctx.get("double").unwrap();
848        let (data, _) = result.as_tensor().unwrap();
849        assert_eq!(data, &[22.0, 24.0]);
850    }
851
852    #[test]
853    fn execute_cached_node() {
854        let (bus, cache) = setup();
855        let key = CacheKey::hash_data(b"cached_result");
856        let cached_value = Value::tensor(vec![99.0], vec![1]);
857        cache.put(&key, &cached_value).unwrap();
858
859        let mut ctx = Context::new(bus, "run_1");
860        let filters = FilterLibrary::new();
861
862        let plan = ExecutionPlan::Cached {
863            node_id: "cached_node".into(),
864            key,
865        };
866
867        execute(&plan, &mut ctx, &filters, &cache).unwrap();
868        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
869    }
870
871    #[test]
872    fn execute_emits_events() {
873        let bus = Arc::new(EventBus::new(64));
874        let cache = MemoryCache::default();
875        let mut rx = bus.subscribe();
876
877        let mut ctx = Context::new(bus, "run_1");
878        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
879        ctx.graph_info
880            .set_predecessors("double", vec!["input".into()]);
881
882        let mut filters = FilterLibrary::new();
883        filters.register("double", Box::new(DoublerFilter));
884
885        execute(
886            &ExecutionPlan::Execute {
887                node_id: "double".into(),
888            },
889            &mut ctx,
890            &filters,
891            &cache,
892        )
893        .unwrap();
894
895        let e1 = rx.try_recv().unwrap();
896        assert!(matches!(e1, Event::NodeStarted { .. }));
897        let e2 = rx.try_recv().unwrap();
898        assert!(matches!(e2, Event::NodeCompleted { .. }));
899    }
900
901    #[test]
902    fn execute_missing_filter_errors() {
903        let (bus, cache) = setup();
904        let mut ctx = Context::new(bus, "run_1");
905        let filters = FilterLibrary::new();
906
907        let result = execute(
908            &ExecutionPlan::Execute {
909                node_id: "nonexistent".into(),
910            },
911            &mut ctx,
912            &filters,
913            &cache,
914        );
915        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
916    }
917
918    #[test]
919    fn execute_empty_plan() {
920        let (bus, cache) = setup();
921        let mut ctx = Context::new(bus, "run_1");
922        let filters = FilterLibrary::new();
923        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
924    }
925
926    #[test]
927    fn execute_parallel_branches_merge_outputs() {
928        let (bus, cache) = setup();
929        let mut ctx = Context::new(bus, "run_1");
930        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
931        ctx.graph_info
932            .set_predecessors("double", vec!["input".into()]);
933        ctx.graph_info.set_predecessors("add", vec!["input".into()]);
934
935        let mut filters = FilterLibrary::new();
936        filters.register("double", Box::new(DoublerFilter));
937        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));
938
939        let plan = ExecutionPlan::Parallel(vec![
940            ExecutionPlan::Execute {
941                node_id: "double".into(),
942            },
943            ExecutionPlan::Execute {
944                node_id: "add".into(),
945            },
946        ]);
947
948        execute(&plan, &mut ctx, &filters, &cache).unwrap();
949
950        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
951        assert_eq!(double_out, &[10.0]);
952        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
953        assert_eq!(add_out, &[105.0]);
954    }
955
956    #[test]
957    fn parallel_branches_run_concurrently() {
958        let (bus, cache) = setup();
959        let mut ctx = Context::new(bus, "run_1");
960        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
961        ctx.graph_info
962            .set_predecessors("slow_a", vec!["input".into()]);
963        ctx.graph_info
964            .set_predecessors("slow_b", vec!["input".into()]);
965
966        let mut filters = FilterLibrary::new();
967        filters.register(
968            "slow_a",
969            Box::new(SlowFilter {
970                id: "a".into(),
971                delay_ms: 50,
972            }),
973        );
974        filters.register(
975            "slow_b",
976            Box::new(SlowFilter {
977                id: "b".into(),
978                delay_ms: 50,
979            }),
980        );
981
982        let plan = ExecutionPlan::Parallel(vec![
983            ExecutionPlan::Execute {
984                node_id: "slow_a".into(),
985            },
986            ExecutionPlan::Execute {
987                node_id: "slow_b".into(),
988            },
989        ]);
990
991        let start = Instant::now();
992        execute(&plan, &mut ctx, &filters, &cache).unwrap();
993        let elapsed = start.elapsed();
994
995        // If truly parallel: ~50ms. If sequential: ~100ms.
996        // Use 90ms as threshold to account for overhead.
997        assert!(
998            elapsed.as_millis() < 90,
999            "parallel branches took {}ms, expected <90ms (sequential would be ~100ms)",
1000            elapsed.as_millis()
1001        );
1002
1003        assert!(ctx.get("slow_a").is_some());
1004        assert!(ctx.get("slow_b").is_some());
1005    }
1006
1007    #[test]
1008    fn resolve_input_single_predecessor() {
1009        let bus = Arc::new(EventBus::new(8));
1010        let mut ctx = Context::new(bus, "r");
1011        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
1012        ctx.graph_info.set_predecessors("B", vec!["A".into()]);
1013
1014        let input = resolve_input("B", &ctx);
1015        let (data, _) = input.as_tensor().unwrap();
1016        assert_eq!(data, &[42.0]);
1017    }
1018
1019    #[test]
1020    fn resolve_input_multiple_predecessors() {
1021        let bus = Arc::new(EventBus::new(8));
1022        let mut ctx = Context::new(bus, "r");
1023        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
1024        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
1025        ctx.graph_info
1026            .set_predecessors("C", vec!["A".into(), "B".into()]);
1027
1028        let input = resolve_input("C", &ctx);
1029        let json = input.as_json().unwrap();
1030        assert!(json.get("A").is_some());
1031        assert!(json.get("B").is_some());
1032    }
1033
1034    #[test]
1035    fn resolve_input_no_predecessors_fallback() {
1036        let bus = Arc::new(EventBus::new(8));
1037        let mut ctx = Context::new(bus, "r");
1038        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));
1039
1040        let input = resolve_input("root", &ctx);
1041        let (data, _) = input.as_tensor().unwrap();
1042        assert_eq!(data, &[7.0]);
1043    }
1044
1045    #[test]
1046    fn graph_info_from_linear() {
1047        let info = GraphInfo::for_linear(&["a", "b", "c"]);
1048        assert!(info.predecessors("a").is_empty());
1049        assert_eq!(info.predecessors("b"), &["a"]);
1050        assert_eq!(info.predecessors("c"), &["b"]);
1051    }
1052
1053    #[test]
1054    fn execute_stream_chunks_input() {
1055        let (bus, cache) = setup();
1056        let mut ctx = Context::new(bus, "run_stream");
1057        // 6-element input, chunk_size=2 → 3 chunks
1058        ctx.set(
1059            "__input__",
1060            Value::tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]),
1061        );
1062        ctx.graph_info
1063            .set_predecessors("double", vec!["__input__".into()]);
1064
1065        let mut filters = FilterLibrary::new();
1066        filters.register("double", Box::new(DoublerFilter));
1067
1068        let plan = ExecutionPlan::Stream {
1069            node_ids: vec!["double".into()],
1070            chunk_size: 2,
1071        };
1072
1073        execute(&plan, &mut ctx, &filters, &cache).unwrap();
1074
1075        let result = ctx.get("double").unwrap();
1076        let (data, shape) = result.as_tensor().unwrap();
1077        assert_eq!(data, &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
1078        assert_eq!(shape, &[6]);
1079    }
1080
1081    #[test]
1082    fn execute_stream_chain() {
1083        let (bus, cache) = setup();
1084        let mut ctx = Context::new(bus, "run_stream_chain");
1085        ctx.set(
1086            "__input__",
1087            Value::tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]),
1088        );
1089        ctx.graph_info
1090            .set_predecessors("double", vec!["__input__".into()]);
1091        ctx.graph_info
1092            .set_predecessors("add", vec!["double".into()]);
1093
1094        let mut filters = FilterLibrary::new();
1095        filters.register("double", Box::new(DoublerFilter));
1096        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
1097
1098        let plan = ExecutionPlan::Stream {
1099            node_ids: vec!["double".into(), "add".into()],
1100            chunk_size: 2,
1101        };
1102
1103        execute(&plan, &mut ctx, &filters, &cache).unwrap();
1104
1105        // double → add: [1,2,3,4] → [2,4,6,8] → [12,14,16,18]
1106        let result = ctx.get("add").unwrap();
1107        let (data, shape) = result.as_tensor().unwrap();
1108        assert_eq!(data, &[12.0, 14.0, 16.0, 18.0]);
1109        assert_eq!(shape, &[4]);
1110    }
1111
1112    #[test]
1113    fn execute_stream_single_chunk() {
1114        let (bus, cache) = setup();
1115        let mut ctx = Context::new(bus, "run_stream_single");
1116        ctx.set("__input__", Value::tensor(vec![5.0, 10.0], vec![2]));
1117        ctx.graph_info
1118            .set_predecessors("double", vec!["__input__".into()]);
1119
1120        let mut filters = FilterLibrary::new();
1121        filters.register("double", Box::new(DoublerFilter));
1122
1123        // chunk_size larger than input → single chunk
1124        let plan = ExecutionPlan::Stream {
1125            node_ids: vec!["double".into()],
1126            chunk_size: 1000,
1127        };
1128
1129        execute(&plan, &mut ctx, &filters, &cache).unwrap();
1130
1131        let result = ctx.get("double").unwrap();
1132        let (data, _) = result.as_tensor().unwrap();
1133        assert_eq!(data, &[10.0, 20.0]);
1134    }
1135}