Skip to main content

somatize_runtime/
executor.rs

//! Plan executor — walks [`ExecutionPlan`] trees and runs filter nodes.
//!
//! Handles sequential, parallel (scoped threads), cached, remote, loop,
//! branch, composite, and streaming execution. Uses [`GraphInfo`] for
//! topology-aware input resolution.
5
6use crate::event_bus::EventBus;
7use crate::filter_library::FilterLibrary;
8use somatize_compiler::ExecutionPlan;
9use somatize_core::cache::CacheStore;
10use somatize_core::error::{Result, SomaError};
11use somatize_core::event::Event;
12use somatize_core::store::DataStore;
13use somatize_core::value::Value;
14use somatize_core::virtual_value::VirtualValue;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Instant;
18
19/// Graph topology information for input resolution.
20///
21/// Maps each node to its predecessor node IDs so the executor knows
22/// where to read inputs from in the context store.
23#[derive(Debug, Clone, Default)]
24pub struct GraphInfo {
25    /// node_id → list of predecessor node IDs
26    predecessors: HashMap<String, Vec<String>>,
27}
28
29impl GraphInfo {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Register predecessors for a node.
35    pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36        self.predecessors.insert(node_id.into(), preds);
37    }
38
39    /// Build GraphInfo from a somatize_core::graph::Graph.
40    pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41        let mut info = Self::new();
42        for node in &graph.nodes {
43            let preds: Vec<String> = graph
44                .predecessors(&node.id)
45                .into_iter()
46                .map(|s| s.to_string())
47                .collect();
48            info.set_predecessors(node.id.clone(), preds);
49        }
50        info
51    }
52
53    /// Build GraphInfo for a linear pipeline (each node depends on the previous).
54    pub fn for_linear(node_ids: &[&str]) -> Self {
55        let mut info = Self::new();
56        for (i, &id) in node_ids.iter().enumerate() {
57            let preds = if i > 0 {
58                vec![node_ids[i - 1].to_string()]
59            } else {
60                vec![]
61            };
62            info.set_predecessors(id, preds);
63        }
64        info
65    }
66
67    /// Get predecessors for a node.
68    pub fn predecessors(&self, node_id: &str) -> &[String] {
69        self.predecessors
70            .get(node_id)
71            .map(|v| v.as_slice())
72            .unwrap_or(&[])
73    }
74}
75
76/// Execution context passed to filters during runtime.
77///
78/// Node outputs are stored as [`VirtualValue`]s — they may be materialized
79/// in memory, cached on disk, or deferred (not yet computed). The executor
80/// resolves them on demand when a downstream node needs the data.
81pub struct Context {
82    /// Node outputs as virtual values (may be lazy).
83    pub store: HashMap<String, VirtualValue>,
84    /// Event bus for emitting runtime events.
85    pub event_bus: Arc<EventBus>,
86    /// Current run ID.
87    pub run_id: String,
88    /// Track execution order.
89    pub execution_order: Vec<String>,
90    /// Graph topology for input resolution.
91    pub graph_info: GraphInfo,
92    /// Optional transport for distributed plans.
93    pub transport: Option<Arc<dyn crate::runner::Transport>>,
94    /// Optional data store for persisting intermediate results.
95    pub data_store: Option<Arc<dyn DataStore>>,
96    /// Minimum value size (bytes) to spill to DataStore instead of keeping in memory.
97    /// Default: 0 (disabled — all values stay in memory).
98    pub spill_threshold: usize,
99}
100
101impl Context {
102    pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
103        Self {
104            store: HashMap::new(),
105            event_bus,
106            run_id: run_id.into(),
107            execution_order: Vec::new(),
108            graph_info: GraphInfo::new(),
109            transport: None,
110            data_store: None,
111            spill_threshold: 0,
112        }
113    }
114
115    pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
116        self.graph_info = info;
117        self
118    }
119
120    pub fn with_transport(mut self, transport: Arc<dyn crate::runner::Transport>) -> Self {
121        self.transport = Some(transport);
122        self
123    }
124
125    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
126        self.data_store = Some(store);
127        self
128    }
129
130    /// Set spill threshold: values larger than this (in bytes) are offloaded
131    /// to the DataStore and replaced with a VirtualValue::Cached reference.
132    /// Requires a DataStore to be set via `with_data_store()`.
133    pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
134        self.spill_threshold = bytes;
135        self
136    }
137
138    /// If a DataStore and spill threshold are configured, check if the value
139    /// should be offloaded. Returns VirtualValue (materialized or cached ref).
140    fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
141        if self.spill_threshold > 0
142            && let Some(store) = &self.data_store
143        {
144            let size = value.size() * 8; // approximate bytes (f64 = 8 bytes)
145            if size >= self.spill_threshold {
146                let key = somatize_core::cache::CacheKey::from_parts(&[
147                    self.run_id.as_bytes(),
148                    node_id.as_bytes(),
149                ]);
150                let vv_for_schema = VirtualValue::materialized(value.clone());
151                let schema = vv_for_schema.schema().clone();
152                if let Ok(_data_ref) = store.put(&key, &value) {
153                    tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
154                    return VirtualValue::cached(key, schema);
155                }
156            }
157        }
158        VirtualValue::materialized(value)
159    }
160
161    /// Get the materialized Value for a node, if present and materialized.
162    pub fn get(&self, node_id: &str) -> Option<&Value> {
163        self.store.get(node_id).and_then(|vv| vv.as_value())
164    }
165
166    /// Get the raw VirtualValue for a node.
167    pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
168        self.store.get(node_id)
169    }
170
171    /// Store a materialized value for a node.
172    pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
173        let id = node_id.into();
174        self.execution_order.push(id.clone());
175        self.store.insert(id, VirtualValue::materialized(value));
176    }
177
178    /// Store a virtual value (which may be deferred or cached).
179    pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
180        let id = node_id.into();
181        self.execution_order.push(id.clone());
182        self.store.insert(id, vv);
183    }
184
185    fn snapshot(&self) -> Self {
186        Self {
187            store: self.store.clone(),
188            event_bus: self.event_bus.clone(),
189            run_id: self.run_id.clone(),
190            execution_order: self.execution_order.clone(),
191            graph_info: self.graph_info.clone(),
192            transport: self.transport.clone(),
193            data_store: self.data_store.clone(),
194            spill_threshold: self.spill_threshold,
195        }
196    }
197}
198
199/// Execute a compiled plan synchronously.
200/// For parallel branches, uses the async executor under the hood.
201/// Contract for executing a plan.
202pub trait Executable {
203    fn execute(
204        &self,
205        ctx: &mut Context,
206        filters: &FilterLibrary,
207        cache: &dyn CacheStore,
208    ) -> Result<()>;
209}
210
211impl Executable for ExecutionPlan {
212    fn execute(
213        &self,
214        ctx: &mut Context,
215        filters: &FilterLibrary,
216        cache: &dyn CacheStore,
217    ) -> Result<()> {
218        match self {
219            ExecutionPlan::Empty => Ok(()),
220
221            ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
222
223            ExecutionPlan::Cached { node_id, key } => {
224                let start = Instant::now();
225                let value = cache.get(key)?.ok_or_else(|| {
226                    SomaError::Cache(format!(
227                        "expected cached value for node `{node_id}` not found"
228                    ))
229                })?;
230                ctx.set(node_id.clone(), value);
231                ctx.event_bus.emit(Event::NodeCacheHit {
232                    run_id: ctx.run_id.clone(),
233                    node_id: node_id.clone(),
234                    key: key.clone(),
235                    tier: somatize_core::cache::CacheTier::Memory,
236                    load_time: start.elapsed(),
237                });
238                Ok(())
239            }
240
241            ExecutionPlan::Sequence(steps) => {
242                for step in steps {
243                    step.execute(ctx, filters, cache)?;
244                }
245                Ok(())
246            }
247
248            ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
249
250            ExecutionPlan::Loop {
251                node_id,
252                body,
253                max_iterations,
254            } => {
255                let max = max_iterations.unwrap_or(100);
256                for i in 0..max {
257                    body.execute(ctx, filters, cache)?;
258
259                    // Check termination: if the last executed node produced a Value
260                    // that indicates "done" (true, "done", "stop", or empty), break.
261                    let should_stop = ctx
262                        .execution_order
263                        .last()
264                        .and_then(|last_id| ctx.get(last_id))
265                        .map(|v| match v {
266                            Value::Json(j) => {
267                                j.as_bool() == Some(true)
268                                    || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
269                                    || j.get("done").and_then(|d| d.as_bool()) == Some(true)
270                            }
271                            Value::Empty => true,
272                            _ => false,
273                        })
274                        .unwrap_or(false);
275
276                    if should_stop {
277                        ctx.event_bus.emit(Event::NodeCompleted {
278                            run_id: ctx.run_id.clone(),
279                            node_id: node_id.clone(),
280                            duration: std::time::Duration::ZERO,
281                            output_summary: format!("Loop terminated at iteration {}", i + 1),
282                        });
283                        break;
284                    }
285                }
286                Ok(())
287            }
288
289            ExecutionPlan::Branch { node_id, arms } => {
290                // Execute the branch node first (it produces the condition value)
291                execute_node(node_id, ctx, filters, cache)?;
292
293                // Get the condition result
294                let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
295
296                // Match against arm labels
297                let selected_arm = match &condition {
298                    Value::Json(j) => {
299                        // Try matching by string value, bool, or "branch" field
300                        let selector = j
301                            .as_str()
302                            .map(String::from)
303                            .or_else(|| j.as_bool().map(|b| b.to_string()))
304                            .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
305                            .unwrap_or_else(|| "true".to_string());
306
307                        arms.iter()
308                            .find(|(label, _)| label == &selector)
309                            .or_else(|| {
310                                arms.iter()
311                                    .find(|(label, _)| label == "default" || label == "else")
312                            })
313                            .or_else(|| arms.first())
314                    }
315                    _ => arms.first(),
316                };
317
318                if let Some((label, plan)) = selected_arm {
319                    ctx.event_bus.emit(Event::NodeCompleted {
320                        run_id: ctx.run_id.clone(),
321                        node_id: node_id.clone(),
322                        duration: std::time::Duration::ZERO,
323                        output_summary: format!("Branch selected: {label}"),
324                    });
325                    plan.execute(ctx, filters, cache)?;
326                }
327                Ok(())
328            }
329
330            ExecutionPlan::Remote {
331                node_id,
332                target: _,
333                plan,
334            } => {
335                if let Some(transport) = &ctx.transport {
336                    // Gather input from predecessors
337                    let input = ctx
338                        .graph_info
339                        .predecessors(node_id)
340                        .first()
341                        .and_then(|pred| ctx.get(pred));
342
343                    let result = transport.execute_node(node_id, input)?;
344                    ctx.set(node_id.clone(), result);
345                    Ok(())
346                } else {
347                    // No transport — fall back to local execution
348                    plan.execute(ctx, filters, cache)
349                }
350            }
351
352            ExecutionPlan::Composite { node_ids } => {
353                // Sequential fallback — execute each node in order.
354                // A future Python-aware executor will pass tensors directly.
355                for nid in node_ids {
356                    execute_node(nid, ctx, filters, cache)?;
357                }
358                Ok(())
359            }
360
361            ExecutionPlan::Stream {
362                node_ids,
363                chunk_size,
364            } => execute_stream(node_ids, *chunk_size, ctx, filters, cache),
365
366            _ => {
367                tracing::warn!("Unhandled ExecutionPlan variant");
368                Ok(())
369            }
370        }
371    }
372}
373
374/// Execute a plan (convenience function, delegates to `Executable` trait).
375pub fn execute(
376    plan: &ExecutionPlan,
377    ctx: &mut Context,
378    filters: &FilterLibrary,
379    cache: &dyn CacheStore,
380) -> Result<()> {
381    plan.execute(ctx, filters, cache)
382}
383
384/// Execute a single filter node.
385fn execute_node(
386    node_id: &str,
387    ctx: &mut Context,
388    filters: &FilterLibrary,
389    _cache: &dyn CacheStore,
390) -> Result<()> {
391    let start = Instant::now();
392
393    let filter = filters
394        .get(node_id)
395        .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
396
397    ctx.event_bus.emit(Event::NodeStarted {
398        run_id: ctx.run_id.clone(),
399        node_id: node_id.to_string(),
400        kind: filter.meta().kind,
401    });
402
403    let _span = tracing::info_span!("execute_node", %node_id).entered();
404
405    let input = resolve_input(node_id, ctx);
406    // Borrow state via Arc — cloning the inner Value here would deep-copy
407    // potentially huge tensors (encoder outputs, model weights) on every
408    // forward call. Arc::clone is a cheap atomic increment.
409    let state = filters.get_state(node_id);
410    let state_ref: &Value = state.as_deref().unwrap_or(&Value::Empty);
411
412    // catch_unwind: a panic in a user filter must not crash the process
413    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
414        filter.forward(&input, state_ref)
415    }));
416
417    let result = match result {
418        Ok(inner) => inner,
419        Err(panic) => {
420            let msg = panic
421                .downcast_ref::<String>()
422                .map(|s| s.as_str())
423                .or_else(|| panic.downcast_ref::<&str>().copied())
424                .unwrap_or("unknown panic");
425            tracing::error!(node_id, "filter panicked: {msg}");
426            Err(SomaError::Execution {
427                node_id: node_id.to_string(),
428                message: format!("filter panicked: {msg}"),
429            })
430        }
431    };
432
433    match result {
434        Ok(output) => {
435            let duration = start.elapsed();
436            let summary = format!("{output}");
437            let vv = ctx.maybe_spill(node_id, output);
438            ctx.set_virtual(node_id, vv);
439            ctx.event_bus.emit(Event::NodeCompleted {
440                run_id: ctx.run_id.clone(),
441                node_id: node_id.to_string(),
442                duration,
443                output_summary: summary,
444            });
445            Ok(())
446        }
447        Err(e) => {
448            tracing::error!(node_id, error = %e, "node execution failed");
449            ctx.event_bus.emit(Event::NodeFailed {
450                run_id: ctx.run_id.clone(),
451                node_id: node_id.to_string(),
452                error: e.to_string(),
453            });
454            Err(e)
455        }
456    }
457}
458
459/// Execute parallel branches concurrently using std::thread::scope.
460///
461/// Each branch gets a snapshot of the context. After all branches complete,
462/// their new outputs are merged back into the main context.
463fn execute_parallel(
464    branches: &[ExecutionPlan],
465    ctx: &mut Context,
466    filters: &FilterLibrary,
467    cache: &dyn CacheStore,
468) -> Result<()> {
469    let snapshot_keys: Arc<std::collections::HashSet<String>> =
470        Arc::new(ctx.store.keys().cloned().collect());
471
472    // Use scoped threads for true parallelism without Send requirements
473    let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
474        let handles: Vec<_> = branches
475            .iter()
476            .map(|branch| {
477                let mut branch_ctx = ctx.snapshot();
478                let keys = snapshot_keys.clone();
479                s.spawn(move || {
480                    execute(branch, &mut branch_ctx, filters, cache)?;
481                    let new_entries: Vec<(String, VirtualValue)> = branch_ctx
482                        .store
483                        .into_iter()
484                        .filter(|(k, _)| !keys.contains(k))
485                        .collect();
486                    Ok(new_entries)
487                })
488            })
489            .collect();
490
491        handles.into_iter().map(|h| h.join().unwrap()).collect()
492    });
493
494    // Merge results and propagate first error
495    for result in results {
496        let entries = result?;
497        for (key, vv) in entries {
498            ctx.set_virtual(key, vv);
499        }
500    }
501
502    Ok(())
503}
504
505/// Resolve a VirtualValue to a concrete Value, loading from DataStore if needed.
506fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
507    match vv {
508        VirtualValue::Materialized { value, .. } => Some(value.clone()),
509        VirtualValue::Cached { key, .. } => {
510            // Try to load from DataStore
511            if let Some(store) = data_store {
512                let data_ref = somatize_core::store::DataRef::Cached {
513                    cache_key: key.clone(),
514                };
515                store.get(&data_ref).ok()
516            } else {
517                None
518            }
519        }
520        _ => None,
521    }
522}
523
524/// Resolve the input for a node from the context store using graph topology.
525/// If a predecessor was spilled to DataStore, loads it back.
526pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
527    let preds = ctx.graph_info.predecessors(node_id);
528
529    let resolve_node = |id: &str| -> Option<Value> {
530        ctx.store
531            .get(id)
532            .and_then(|vv| resolve_value(vv, &ctx.data_store))
533    };
534
535    match preds.len() {
536        0 => ctx
537            .execution_order
538            .last()
539            .and_then(|id| resolve_node(id))
540            .unwrap_or(Value::Empty),
541        1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
542        _ => {
543            let mut merged = serde_json::Map::new();
544            for pred_id in preds {
545                if let Some(val) = resolve_node(pred_id) {
546                    let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
547                    merged.insert(pred_id.clone(), json_val);
548                }
549            }
550            Value::json(serde_json::Value::Object(merged))
551        }
552    }
553}
554
555/// Execute a stream plan: chunk input, process through StreamExecutor, concatenate.
556fn execute_stream(
557    node_ids: &[String],
558    chunk_size: usize,
559    ctx: &mut Context,
560    filters: &FilterLibrary,
561    cache: &dyn CacheStore,
562) -> Result<()> {
563    use crate::executors::{FittedFilter, StreamExecutor};
564
565    let start = Instant::now();
566
567    // Build FittedFilter list from the library in plan order.
568    let fitted: Vec<FittedFilter> = node_ids
569        .iter()
570        .map(|nid| {
571            let filter = filters
572                .get(nid)
573                .ok_or_else(|| SomaError::NodeNotFound(nid.clone()))?;
574            let state = filters
575                .get_state(nid)
576                .unwrap_or_else(|| Arc::new(Value::Empty));
577            Ok(FittedFilter {
578                name: nid.clone(),
579                filter,
580                state,
581            })
582        })
583        .collect::<Result<_>>()?;
584
585    // Resolve input from the first node's predecessors.
586    let first_id = node_ids
587        .first()
588        .ok_or_else(|| SomaError::Other("stream plan has no nodes".into()))?;
589    let input = resolve_input(first_id, ctx);
590
591    // Chunk the input along the first tensor dimension.
592    let chunks = chunk_value(&input, chunk_size);
593
594    // Process chunks through the stream executor.
595    let mut executor = StreamExecutor::new(fitted);
596    if let Some(c) = cache_as_arc(cache) {
597        executor = executor.with_cache(c);
598    }
599
600    let last_id = node_ids.last().unwrap();
601
602    // Incrementally concatenate tensor outputs to avoid holding all chunks
603    // in memory simultaneously. Each chunk is dropped after its data is
604    // extracted, keeping peak memory proportional to the final output size
605    // rather than O(n_chunks × chunk_size).
606    let mut all_data: Vec<f64> = Vec::new();
607    let mut result_shape: Option<Vec<usize>> = None;
608    let mut non_tensor_output: Option<Value> = None;
609
610    let mut append_output = |output: Value| {
611        match &output {
612            Value::Tensor { values, shape } => {
613                if result_shape.is_none() {
614                    result_shape = Some(shape.clone());
615                }
616                all_data.extend_from_slice(values.as_slice());
617                // `output` is dropped here, freeing the chunk's allocation
618            }
619            _ => {
620                non_tensor_output = Some(output);
621            }
622        }
623    };
624
625    for (i, chunk) in chunks.into_iter().enumerate() {
626        ctx.event_bus.emit(Event::NodeStarted {
627            run_id: ctx.run_id.clone(),
628            node_id: format!("{last_id}#chunk_{i}"),
629            kind: somatize_core::filter::FilterKind::Stateless,
630        });
631        if let Some(output) = executor.process_chunk(chunk)? {
632            append_output(output);
633        }
634    }
635
636    // Flush barrier filters.
637    if let Some(flushed) = executor.flush()? {
638        append_output(flushed);
639    }
640
641    // Build the final result.
642    let result = if let Some(mut shape) = result_shape {
643        // Tensor output: fix the first dimension to reflect total rows.
644        let row_size: usize = shape.iter().skip(1).product::<usize>().max(1);
645        shape[0] = all_data.len() / row_size;
646        Value::tensor(all_data, shape)
647    } else {
648        non_tensor_output.unwrap_or(Value::Empty)
649    };
650
651    let duration = start.elapsed();
652    ctx.set(last_id.clone(), result);
653    ctx.event_bus.emit(Event::NodeCompleted {
654        run_id: ctx.run_id.clone(),
655        node_id: last_id.clone(),
656        duration,
657        output_summary: format!(
658            "stream: {} chunks through {} filters",
659            executor.chunks_processed(),
660            node_ids.len()
661        ),
662    });
663
664    Ok(())
665}
666
667/// Split a Value::Tensor along the first dimension into chunks.
668fn chunk_value(x: &Value, chunk_size: usize) -> Vec<Value> {
669    match x {
670        Value::Tensor { values, shape } if !values.is_empty() && chunk_size > 0 => {
671            let row_size = if shape.len() > 1 {
672                shape[1..].iter().product()
673            } else {
674                1
675            };
676            let n_rows = shape[0];
677            let mut chunks = Vec::new();
678            for start in (0..n_rows).step_by(chunk_size) {
679                let end = (start + chunk_size).min(n_rows);
680                let flat_start = start * row_size;
681                let flat_end = end * row_size;
682                let chunk_vals = values[flat_start..flat_end].to_vec();
683                let mut chunk_shape = shape.clone();
684                chunk_shape[0] = end - start;
685                chunks.push(Value::tensor(chunk_vals, chunk_shape));
686            }
687            chunks
688        }
689        _ => vec![x.clone()],
690    }
691}
692
693/// Try to wrap the cache reference as an Arc for StreamExecutor.
694/// This is safe because the cache outlives the executor within execute().
695fn cache_as_arc(_cache: &dyn CacheStore) -> Option<Arc<dyn CacheStore>> {
696    // StreamExecutor requires Arc, but we only have a reference.
697    // For local execution we skip the cache in streaming mode — the per-chunk
698    // cache is an optimization, not a correctness requirement.
699    // TODO: pass cache as Option<&dyn CacheStore> to StreamExecutor.
700    None
701}
702
703#[cfg(test)]
704mod tests {
705    use super::*;
706    use crate::cache::MemoryCache;
707    use somatize_core::cache::CacheKey;
708    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
709
710    struct DoublerFilter;
711
712    impl Filter for DoublerFilter {
713        fn config_hash(&self) -> CacheKey {
714            CacheKey::from_parts(&[b"Doubler"])
715        }
716        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
717            Ok(Value::Empty)
718        }
719        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
720            match x {
721                Value::Tensor { values, shape } => {
722                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
723                    Ok(Value::tensor(doubled, shape.clone()))
724                }
725                _ => Ok(x.clone()),
726            }
727        }
728        fn meta(&self) -> FilterMeta {
729            FilterMeta {
730                name: "Doubler".into(),
731                kind: FilterKind::Stateless,
732                cacheable: true,
733                differentiable: true,
734                stream_mode: StreamMode::FixedState,
735                distribution: somatize_core::filter::Distribution::Local,
736                input_schema: None,
737                output_schema: None,
738            }
739        }
740
741        fn as_any(&self) -> &dyn std::any::Any {
742            self
743        }
744    }
745
746    struct AdderFilter {
747        amount: f64,
748    }
749
750    impl Filter for AdderFilter {
751        fn config_hash(&self) -> CacheKey {
752            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
753        }
754        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
755            Ok(Value::Empty)
756        }
757        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
758            match x {
759                Value::Tensor { values, shape } => {
760                    let added: Vec<f64> = values.iter().map(|v| v + self.amount).collect();
761                    Ok(Value::tensor(added, shape.clone()))
762                }
763                _ => Ok(x.clone()),
764            }
765        }
766        fn meta(&self) -> FilterMeta {
767            FilterMeta {
768                name: "Adder".into(),
769                kind: FilterKind::Stateless,
770                cacheable: true,
771                differentiable: true,
772                stream_mode: StreamMode::FixedState,
773                distribution: somatize_core::filter::Distribution::Local,
774                input_schema: None,
775                output_schema: None,
776            }
777        }
778
779        fn as_any(&self) -> &dyn std::any::Any {
780            self
781        }
782    }
783
784    /// Slow filter that sleeps to verify parallelism.
785    struct SlowFilter {
786        id: String,
787        delay_ms: u64,
788    }
789
790    impl Filter for SlowFilter {
791        fn config_hash(&self) -> CacheKey {
792            CacheKey::from_parts(&[b"Slow", self.id.as_bytes()])
793        }
794        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
795            Ok(Value::Empty)
796        }
797        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
798            std::thread::sleep(std::time::Duration::from_millis(self.delay_ms));
799            Ok(x.clone())
800        }
801        fn meta(&self) -> FilterMeta {
802            FilterMeta {
803                name: format!("Slow_{}", self.id),
804                kind: FilterKind::Stateless,
805                cacheable: false,
806                differentiable: true,
807                stream_mode: StreamMode::FixedState,
808                distribution: somatize_core::filter::Distribution::Local,
809                input_schema: None,
810                output_schema: None,
811            }
812        }
813
814        fn as_any(&self) -> &dyn std::any::Any {
815            self
816        }
817    }
818
819    fn setup() -> (Arc<EventBus>, MemoryCache) {
820        (Arc::new(EventBus::new(64)), MemoryCache::default())
821    }
822
823    #[test]
824    fn execute_single_node() {
825        let (bus, cache) = setup();
826        let mut ctx = Context::new(bus, "run_1");
827        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
828        ctx.graph_info
829            .set_predecessors("doubler", vec!["input".into()]);
830
831        let mut filters = FilterLibrary::new();
832        filters.register("doubler", Box::new(DoublerFilter));
833
834        let plan = ExecutionPlan::Execute {
835            node_id: "doubler".into(),
836        };
837
838        execute(&plan, &mut ctx, &filters, &cache).unwrap();
839
840        let result = ctx.get("doubler").unwrap();
841        let (data, _) = result.as_tensor().unwrap();
842        assert_eq!(data, &[2.0, 4.0, 6.0]);
843    }
844
845    #[test]
846    fn execute_sequence_with_graph_info() {
847        let (bus, cache) = setup();
848        let mut ctx = Context::new(bus, "run_1");
849        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));
850
851        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
852        ctx.graph_info = graph_info;
853
854        let mut filters = FilterLibrary::new();
855        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
856        filters.register("double", Box::new(DoublerFilter));
857
858        let plan = ExecutionPlan::Sequence(vec![
859            ExecutionPlan::Execute {
860                node_id: "add".into(),
861            },
862            ExecutionPlan::Execute {
863                node_id: "double".into(),
864            },
865        ]);
866
867        execute(&plan, &mut ctx, &filters, &cache).unwrap();
868
869        let result = ctx.get("double").unwrap();
870        let (data, _) = result.as_tensor().unwrap();
871        assert_eq!(data, &[22.0, 24.0]);
872    }
873
874    #[test]
875    fn execute_cached_node() {
876        let (bus, cache) = setup();
877        let key = CacheKey::hash_data(b"cached_result");
878        let cached_value = Value::tensor(vec![99.0], vec![1]);
879        cache.put(&key, &cached_value).unwrap();
880
881        let mut ctx = Context::new(bus, "run_1");
882        let filters = FilterLibrary::new();
883
884        let plan = ExecutionPlan::Cached {
885            node_id: "cached_node".into(),
886            key,
887        };
888
889        execute(&plan, &mut ctx, &filters, &cache).unwrap();
890        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
891    }
892
893    #[test]
894    fn execute_emits_events() {
895        let bus = Arc::new(EventBus::new(64));
896        let cache = MemoryCache::default();
897        let mut rx = bus.subscribe();
898
899        let mut ctx = Context::new(bus, "run_1");
900        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
901        ctx.graph_info
902            .set_predecessors("double", vec!["input".into()]);
903
904        let mut filters = FilterLibrary::new();
905        filters.register("double", Box::new(DoublerFilter));
906
907        execute(
908            &ExecutionPlan::Execute {
909                node_id: "double".into(),
910            },
911            &mut ctx,
912            &filters,
913            &cache,
914        )
915        .unwrap();
916
917        let e1 = rx.try_recv().unwrap();
918        assert!(matches!(e1, Event::NodeStarted { .. }));
919        let e2 = rx.try_recv().unwrap();
920        assert!(matches!(e2, Event::NodeCompleted { .. }));
921    }
922
923    #[test]
924    fn execute_missing_filter_errors() {
925        let (bus, cache) = setup();
926        let mut ctx = Context::new(bus, "run_1");
927        let filters = FilterLibrary::new();
928
929        let result = execute(
930            &ExecutionPlan::Execute {
931                node_id: "nonexistent".into(),
932            },
933            &mut ctx,
934            &filters,
935            &cache,
936        );
937        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
938    }
939
940    #[test]
941    fn execute_empty_plan() {
942        let (bus, cache) = setup();
943        let mut ctx = Context::new(bus, "run_1");
944        let filters = FilterLibrary::new();
945        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
946    }
947
948    #[test]
949    fn execute_parallel_branches_merge_outputs() {
950        let (bus, cache) = setup();
951        let mut ctx = Context::new(bus, "run_1");
952        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
953        ctx.graph_info
954            .set_predecessors("double", vec!["input".into()]);
955        ctx.graph_info.set_predecessors("add", vec!["input".into()]);
956
957        let mut filters = FilterLibrary::new();
958        filters.register("double", Box::new(DoublerFilter));
959        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));
960
961        let plan = ExecutionPlan::Parallel(vec![
962            ExecutionPlan::Execute {
963                node_id: "double".into(),
964            },
965            ExecutionPlan::Execute {
966                node_id: "add".into(),
967            },
968        ]);
969
970        execute(&plan, &mut ctx, &filters, &cache).unwrap();
971
972        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
973        assert_eq!(double_out, &[10.0]);
974        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
975        assert_eq!(add_out, &[105.0]);
976    }
977
978    #[test]
979    fn parallel_branches_run_concurrently() {
980        let (bus, cache) = setup();
981        let mut ctx = Context::new(bus, "run_1");
982        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
983        ctx.graph_info
984            .set_predecessors("slow_a", vec!["input".into()]);
985        ctx.graph_info
986            .set_predecessors("slow_b", vec!["input".into()]);
987
988        let mut filters = FilterLibrary::new();
989        filters.register(
990            "slow_a",
991            Box::new(SlowFilter {
992                id: "a".into(),
993                delay_ms: 50,
994            }),
995        );
996        filters.register(
997            "slow_b",
998            Box::new(SlowFilter {
999                id: "b".into(),
1000                delay_ms: 50,
1001            }),
1002        );
1003
1004        let plan = ExecutionPlan::Parallel(vec![
1005            ExecutionPlan::Execute {
1006                node_id: "slow_a".into(),
1007            },
1008            ExecutionPlan::Execute {
1009                node_id: "slow_b".into(),
1010            },
1011        ]);
1012
1013        let start = Instant::now();
1014        execute(&plan, &mut ctx, &filters, &cache).unwrap();
1015        let elapsed = start.elapsed();
1016
1017        // If truly parallel: ~50ms. If sequential: ~100ms.
1018        // Use 90ms as threshold to account for overhead.
1019        assert!(
1020            elapsed.as_millis() < 90,
1021            "parallel branches took {}ms, expected <90ms (sequential would be ~100ms)",
1022            elapsed.as_millis()
1023        );
1024
1025        assert!(ctx.get("slow_a").is_some());
1026        assert!(ctx.get("slow_b").is_some());
1027    }
1028
1029    #[test]
1030    fn resolve_input_single_predecessor() {
1031        let bus = Arc::new(EventBus::new(8));
1032        let mut ctx = Context::new(bus, "r");
1033        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
1034        ctx.graph_info.set_predecessors("B", vec!["A".into()]);
1035
1036        let input = resolve_input("B", &ctx);
1037        let (data, _) = input.as_tensor().unwrap();
1038        assert_eq!(data, &[42.0]);
1039    }
1040
1041    #[test]
1042    fn resolve_input_multiple_predecessors() {
1043        let bus = Arc::new(EventBus::new(8));
1044        let mut ctx = Context::new(bus, "r");
1045        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
1046        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
1047        ctx.graph_info
1048            .set_predecessors("C", vec!["A".into(), "B".into()]);
1049
1050        let input = resolve_input("C", &ctx);
1051        let json = input.as_json().unwrap();
1052        assert!(json.get("A").is_some());
1053        assert!(json.get("B").is_some());
1054    }
1055
1056    #[test]
1057    fn resolve_input_no_predecessors_fallback() {
1058        let bus = Arc::new(EventBus::new(8));
1059        let mut ctx = Context::new(bus, "r");
1060        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));
1061
1062        let input = resolve_input("root", &ctx);
1063        let (data, _) = input.as_tensor().unwrap();
1064        assert_eq!(data, &[7.0]);
1065    }
1066
1067    #[test]
1068    fn graph_info_from_linear() {
1069        let info = GraphInfo::for_linear(&["a", "b", "c"]);
1070        assert!(info.predecessors("a").is_empty());
1071        assert_eq!(info.predecessors("b"), &["a"]);
1072        assert_eq!(info.predecessors("c"), &["b"]);
1073    }
1074
1075    #[test]
1076    fn execute_stream_chunks_input() {
1077        let (bus, cache) = setup();
1078        let mut ctx = Context::new(bus, "run_stream");
1079        // 6-element input, chunk_size=2 → 3 chunks
1080        ctx.set(
1081            "__input__",
1082            Value::tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]),
1083        );
1084        ctx.graph_info
1085            .set_predecessors("double", vec!["__input__".into()]);
1086
1087        let mut filters = FilterLibrary::new();
1088        filters.register("double", Box::new(DoublerFilter));
1089
1090        let plan = ExecutionPlan::Stream {
1091            node_ids: vec!["double".into()],
1092            chunk_size: 2,
1093        };
1094
1095        execute(&plan, &mut ctx, &filters, &cache).unwrap();
1096
1097        let result = ctx.get("double").unwrap();
1098        let (data, shape) = result.as_tensor().unwrap();
1099        assert_eq!(data, &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
1100        assert_eq!(shape, &[6]);
1101    }
1102
1103    #[test]
1104    fn execute_stream_chain() {
1105        let (bus, cache) = setup();
1106        let mut ctx = Context::new(bus, "run_stream_chain");
1107        ctx.set(
1108            "__input__",
1109            Value::tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]),
1110        );
1111        ctx.graph_info
1112            .set_predecessors("double", vec!["__input__".into()]);
1113        ctx.graph_info
1114            .set_predecessors("add", vec!["double".into()]);
1115
1116        let mut filters = FilterLibrary::new();
1117        filters.register("double", Box::new(DoublerFilter));
1118        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
1119
1120        let plan = ExecutionPlan::Stream {
1121            node_ids: vec!["double".into(), "add".into()],
1122            chunk_size: 2,
1123        };
1124
1125        execute(&plan, &mut ctx, &filters, &cache).unwrap();
1126
1127        // double → add: [1,2,3,4] → [2,4,6,8] → [12,14,16,18]
1128        let result = ctx.get("add").unwrap();
1129        let (data, shape) = result.as_tensor().unwrap();
1130        assert_eq!(data, &[12.0, 14.0, 16.0, 18.0]);
1131        assert_eq!(shape, &[4]);
1132    }
1133
1134    #[test]
1135    fn execute_stream_single_chunk() {
1136        let (bus, cache) = setup();
1137        let mut ctx = Context::new(bus, "run_stream_single");
1138        ctx.set("__input__", Value::tensor(vec![5.0, 10.0], vec![2]));
1139        ctx.graph_info
1140            .set_predecessors("double", vec!["__input__".into()]);
1141
1142        let mut filters = FilterLibrary::new();
1143        filters.register("double", Box::new(DoublerFilter));
1144
1145        // chunk_size larger than input → single chunk
1146        let plan = ExecutionPlan::Stream {
1147            node_ids: vec!["double".into()],
1148            chunk_size: 1000,
1149        };
1150
1151        execute(&plan, &mut ctx, &filters, &cache).unwrap();
1152
1153        let result = ctx.get("double").unwrap();
1154        let (data, _) = result.as_tensor().unwrap();
1155        assert_eq!(data, &[10.0, 20.0]);
1156    }
1157}