1use crate::event_bus::EventBus;
7use crate::filter_library::FilterLibrary;
8use somatize_compiler::ExecutionPlan;
9use somatize_core::cache::CacheStore;
10use somatize_core::error::{Result, SomaError};
11use somatize_core::event::Event;
12use somatize_core::store::DataStore;
13use somatize_core::value::Value;
14use somatize_core::virtual_value::VirtualValue;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Instant;
18
19#[derive(Debug, Clone, Default)]
24pub struct GraphInfo {
25 predecessors: HashMap<String, Vec<String>>,
27}
28
29impl GraphInfo {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36 self.predecessors.insert(node_id.into(), preds);
37 }
38
39 pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41 let mut info = Self::new();
42 for node in &graph.nodes {
43 let preds: Vec<String> = graph
44 .predecessors(&node.id)
45 .into_iter()
46 .map(|s| s.to_string())
47 .collect();
48 info.set_predecessors(node.id.clone(), preds);
49 }
50 info
51 }
52
53 pub fn for_linear(node_ids: &[&str]) -> Self {
55 let mut info = Self::new();
56 for (i, &id) in node_ids.iter().enumerate() {
57 let preds = if i > 0 {
58 vec![node_ids[i - 1].to_string()]
59 } else {
60 vec![]
61 };
62 info.set_predecessors(id, preds);
63 }
64 info
65 }
66
67 pub fn predecessors(&self, node_id: &str) -> &[String] {
69 self.predecessors
70 .get(node_id)
71 .map(|v| v.as_slice())
72 .unwrap_or(&[])
73 }
74}
75
76pub struct Context {
82 pub store: HashMap<String, VirtualValue>,
84 pub event_bus: Arc<EventBus>,
86 pub run_id: String,
88 pub execution_order: Vec<String>,
90 pub graph_info: GraphInfo,
92 pub transport: Option<Arc<dyn crate::runner::Transport>>,
94 pub data_store: Option<Arc<dyn DataStore>>,
96 pub spill_threshold: usize,
99}
100
101impl Context {
102 pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
103 Self {
104 store: HashMap::new(),
105 event_bus,
106 run_id: run_id.into(),
107 execution_order: Vec::new(),
108 graph_info: GraphInfo::new(),
109 transport: None,
110 data_store: None,
111 spill_threshold: 0,
112 }
113 }
114
115 pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
116 self.graph_info = info;
117 self
118 }
119
120 pub fn with_transport(mut self, transport: Arc<dyn crate::runner::Transport>) -> Self {
121 self.transport = Some(transport);
122 self
123 }
124
125 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
126 self.data_store = Some(store);
127 self
128 }
129
130 pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
134 self.spill_threshold = bytes;
135 self
136 }
137
138 fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
141 if self.spill_threshold > 0
142 && let Some(store) = &self.data_store
143 {
144 let size = value.size() * 8; if size >= self.spill_threshold {
146 let key = somatize_core::cache::CacheKey::from_parts(&[
147 self.run_id.as_bytes(),
148 node_id.as_bytes(),
149 ]);
150 let vv_for_schema = VirtualValue::materialized(value.clone());
151 let schema = vv_for_schema.schema().clone();
152 if let Ok(_data_ref) = store.put(&key, &value) {
153 tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
154 return VirtualValue::cached(key, schema);
155 }
156 }
157 }
158 VirtualValue::materialized(value)
159 }
160
161 pub fn get(&self, node_id: &str) -> Option<&Value> {
163 self.store.get(node_id).and_then(|vv| vv.as_value())
164 }
165
166 pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
168 self.store.get(node_id)
169 }
170
171 pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
173 let id = node_id.into();
174 self.execution_order.push(id.clone());
175 self.store.insert(id, VirtualValue::materialized(value));
176 }
177
178 pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
180 let id = node_id.into();
181 self.execution_order.push(id.clone());
182 self.store.insert(id, vv);
183 }
184
185 fn snapshot(&self) -> Self {
186 Self {
187 store: self.store.clone(),
188 event_bus: self.event_bus.clone(),
189 run_id: self.run_id.clone(),
190 execution_order: self.execution_order.clone(),
191 graph_info: self.graph_info.clone(),
192 transport: self.transport.clone(),
193 data_store: self.data_store.clone(),
194 spill_threshold: self.spill_threshold,
195 }
196 }
197}
198
199pub trait Executable {
203 fn execute(
204 &self,
205 ctx: &mut Context,
206 filters: &FilterLibrary,
207 cache: &dyn CacheStore,
208 ) -> Result<()>;
209}
210
211impl Executable for ExecutionPlan {
212 fn execute(
213 &self,
214 ctx: &mut Context,
215 filters: &FilterLibrary,
216 cache: &dyn CacheStore,
217 ) -> Result<()> {
218 match self {
219 ExecutionPlan::Empty => Ok(()),
220
221 ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
222
223 ExecutionPlan::Cached { node_id, key } => {
224 let start = Instant::now();
225 let value = cache.get(key)?.ok_or_else(|| {
226 SomaError::Cache(format!(
227 "expected cached value for node `{node_id}` not found"
228 ))
229 })?;
230 ctx.set(node_id.clone(), value);
231 ctx.event_bus.emit(Event::NodeCacheHit {
232 run_id: ctx.run_id.clone(),
233 node_id: node_id.clone(),
234 key: key.clone(),
235 tier: somatize_core::cache::CacheTier::Memory,
236 load_time: start.elapsed(),
237 });
238 Ok(())
239 }
240
241 ExecutionPlan::Sequence(steps) => {
242 for step in steps {
243 step.execute(ctx, filters, cache)?;
244 }
245 Ok(())
246 }
247
248 ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
249
250 ExecutionPlan::Loop {
251 node_id,
252 body,
253 max_iterations,
254 } => {
255 let max = max_iterations.unwrap_or(100);
256 for i in 0..max {
257 body.execute(ctx, filters, cache)?;
258
259 let should_stop = ctx
262 .execution_order
263 .last()
264 .and_then(|last_id| ctx.get(last_id))
265 .map(|v| match v {
266 Value::Json(j) => {
267 j.as_bool() == Some(true)
268 || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
269 || j.get("done").and_then(|d| d.as_bool()) == Some(true)
270 }
271 Value::Empty => true,
272 _ => false,
273 })
274 .unwrap_or(false);
275
276 if should_stop {
277 ctx.event_bus.emit(Event::NodeCompleted {
278 run_id: ctx.run_id.clone(),
279 node_id: node_id.clone(),
280 duration: std::time::Duration::ZERO,
281 output_summary: format!("Loop terminated at iteration {}", i + 1),
282 });
283 break;
284 }
285 }
286 Ok(())
287 }
288
289 ExecutionPlan::Branch { node_id, arms } => {
290 execute_node(node_id, ctx, filters, cache)?;
292
293 let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
295
296 let selected_arm = match &condition {
298 Value::Json(j) => {
299 let selector = j
301 .as_str()
302 .map(String::from)
303 .or_else(|| j.as_bool().map(|b| b.to_string()))
304 .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
305 .unwrap_or_else(|| "true".to_string());
306
307 arms.iter()
308 .find(|(label, _)| label == &selector)
309 .or_else(|| {
310 arms.iter()
311 .find(|(label, _)| label == "default" || label == "else")
312 })
313 .or_else(|| arms.first())
314 }
315 _ => arms.first(),
316 };
317
318 if let Some((label, plan)) = selected_arm {
319 ctx.event_bus.emit(Event::NodeCompleted {
320 run_id: ctx.run_id.clone(),
321 node_id: node_id.clone(),
322 duration: std::time::Duration::ZERO,
323 output_summary: format!("Branch selected: {label}"),
324 });
325 plan.execute(ctx, filters, cache)?;
326 }
327 Ok(())
328 }
329
330 ExecutionPlan::Remote {
331 node_id,
332 target: _,
333 plan,
334 } => {
335 if let Some(transport) = &ctx.transport {
336 let input = ctx
338 .graph_info
339 .predecessors(node_id)
340 .first()
341 .and_then(|pred| ctx.get(pred));
342
343 let result = transport.execute_node(node_id, input)?;
344 ctx.set(node_id.clone(), result);
345 Ok(())
346 } else {
347 plan.execute(ctx, filters, cache)
349 }
350 }
351
352 ExecutionPlan::Composite { node_ids } => {
353 for nid in node_ids {
356 execute_node(nid, ctx, filters, cache)?;
357 }
358 Ok(())
359 }
360
361 ExecutionPlan::Stream {
362 node_ids,
363 chunk_size,
364 } => execute_stream(node_ids, *chunk_size, ctx, filters, cache),
365
366 _ => {
367 tracing::warn!("Unhandled ExecutionPlan variant");
368 Ok(())
369 }
370 }
371 }
372}
373
374pub fn execute(
376 plan: &ExecutionPlan,
377 ctx: &mut Context,
378 filters: &FilterLibrary,
379 cache: &dyn CacheStore,
380) -> Result<()> {
381 plan.execute(ctx, filters, cache)
382}
383
384fn execute_node(
386 node_id: &str,
387 ctx: &mut Context,
388 filters: &FilterLibrary,
389 _cache: &dyn CacheStore,
390) -> Result<()> {
391 let start = Instant::now();
392
393 let filter = filters
394 .get(node_id)
395 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
396
397 ctx.event_bus.emit(Event::NodeStarted {
398 run_id: ctx.run_id.clone(),
399 node_id: node_id.to_string(),
400 kind: filter.meta().kind,
401 });
402
403 let _span = tracing::info_span!("execute_node", %node_id).entered();
404
405 let input = resolve_input(node_id, ctx);
406 let state = filters.get_state(node_id);
410 let state_ref: &Value = state.as_deref().unwrap_or(&Value::Empty);
411
412 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
414 filter.forward(&input, state_ref)
415 }));
416
417 let result = match result {
418 Ok(inner) => inner,
419 Err(panic) => {
420 let msg = panic
421 .downcast_ref::<String>()
422 .map(|s| s.as_str())
423 .or_else(|| panic.downcast_ref::<&str>().copied())
424 .unwrap_or("unknown panic");
425 tracing::error!(node_id, "filter panicked: {msg}");
426 Err(SomaError::Execution {
427 node_id: node_id.to_string(),
428 message: format!("filter panicked: {msg}"),
429 })
430 }
431 };
432
433 match result {
434 Ok(output) => {
435 let duration = start.elapsed();
436 let summary = format!("{output}");
437 let vv = ctx.maybe_spill(node_id, output);
438 ctx.set_virtual(node_id, vv);
439 ctx.event_bus.emit(Event::NodeCompleted {
440 run_id: ctx.run_id.clone(),
441 node_id: node_id.to_string(),
442 duration,
443 output_summary: summary,
444 });
445 Ok(())
446 }
447 Err(e) => {
448 tracing::error!(node_id, error = %e, "node execution failed");
449 ctx.event_bus.emit(Event::NodeFailed {
450 run_id: ctx.run_id.clone(),
451 node_id: node_id.to_string(),
452 error: e.to_string(),
453 });
454 Err(e)
455 }
456 }
457}
458
459fn execute_parallel(
464 branches: &[ExecutionPlan],
465 ctx: &mut Context,
466 filters: &FilterLibrary,
467 cache: &dyn CacheStore,
468) -> Result<()> {
469 let snapshot_keys: Arc<std::collections::HashSet<String>> =
470 Arc::new(ctx.store.keys().cloned().collect());
471
472 let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
474 let handles: Vec<_> = branches
475 .iter()
476 .map(|branch| {
477 let mut branch_ctx = ctx.snapshot();
478 let keys = snapshot_keys.clone();
479 s.spawn(move || {
480 execute(branch, &mut branch_ctx, filters, cache)?;
481 let new_entries: Vec<(String, VirtualValue)> = branch_ctx
482 .store
483 .into_iter()
484 .filter(|(k, _)| !keys.contains(k))
485 .collect();
486 Ok(new_entries)
487 })
488 })
489 .collect();
490
491 handles.into_iter().map(|h| h.join().unwrap()).collect()
492 });
493
494 for result in results {
496 let entries = result?;
497 for (key, vv) in entries {
498 ctx.set_virtual(key, vv);
499 }
500 }
501
502 Ok(())
503}
504
505fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
507 match vv {
508 VirtualValue::Materialized { value, .. } => Some(value.clone()),
509 VirtualValue::Cached { key, .. } => {
510 if let Some(store) = data_store {
512 let data_ref = somatize_core::store::DataRef::Cached {
513 cache_key: key.clone(),
514 };
515 store.get(&data_ref).ok()
516 } else {
517 None
518 }
519 }
520 _ => None,
521 }
522}
523
524pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
527 let preds = ctx.graph_info.predecessors(node_id);
528
529 let resolve_node = |id: &str| -> Option<Value> {
530 ctx.store
531 .get(id)
532 .and_then(|vv| resolve_value(vv, &ctx.data_store))
533 };
534
535 match preds.len() {
536 0 => ctx
537 .execution_order
538 .last()
539 .and_then(|id| resolve_node(id))
540 .unwrap_or(Value::Empty),
541 1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
542 _ => {
543 let mut merged = serde_json::Map::new();
544 for pred_id in preds {
545 if let Some(val) = resolve_node(pred_id) {
546 let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
547 merged.insert(pred_id.clone(), json_val);
548 }
549 }
550 Value::Json(serde_json::Value::Object(merged))
551 }
552 }
553}
554
555fn execute_stream(
557 node_ids: &[String],
558 chunk_size: usize,
559 ctx: &mut Context,
560 filters: &FilterLibrary,
561 cache: &dyn CacheStore,
562) -> Result<()> {
563 use crate::executors::{FittedFilter, StreamExecutor, materialize_buffer};
564
565 let start = Instant::now();
566
567 let fitted: Vec<FittedFilter> = node_ids
569 .iter()
570 .map(|nid| {
571 let filter = filters
572 .get(nid)
573 .ok_or_else(|| SomaError::NodeNotFound(nid.clone()))?;
574 let state = filters
575 .get_state(nid)
576 .map(|arc| (*arc).clone())
577 .unwrap_or(Value::Empty);
578 Ok(FittedFilter {
579 name: nid.clone(),
580 filter,
581 state,
582 })
583 })
584 .collect::<Result<_>>()?;
585
586 let first_id = node_ids
588 .first()
589 .ok_or_else(|| SomaError::Other("stream plan has no nodes".into()))?;
590 let input = resolve_input(first_id, ctx);
591
592 let chunks = chunk_value(&input, chunk_size);
594
595 let mut executor = StreamExecutor::new(fitted);
597 if let Some(c) = cache_as_arc(cache) {
598 executor = executor.with_cache(c);
599 }
600
601 let last_id = node_ids.last().unwrap();
602
603 let mut outputs: Vec<Value> = Vec::new();
604 for (i, chunk) in chunks.into_iter().enumerate() {
605 ctx.event_bus.emit(Event::NodeStarted {
606 run_id: ctx.run_id.clone(),
607 node_id: format!("{last_id}#chunk_{i}"),
608 kind: somatize_core::filter::FilterKind::Stateless,
609 });
610 if let Some(output) = executor.process_chunk(chunk)? {
611 outputs.push(output);
612 }
613 }
614
615 if let Some(flushed) = executor.flush()? {
617 outputs.push(flushed);
618 }
619
620 let result = if outputs.len() == 1 {
622 outputs.into_iter().next().unwrap()
623 } else if outputs.is_empty() {
624 Value::Empty
625 } else {
626 materialize_buffer(&outputs)?
627 };
628
629 let duration = start.elapsed();
630 ctx.set(last_id.clone(), result);
631 ctx.event_bus.emit(Event::NodeCompleted {
632 run_id: ctx.run_id.clone(),
633 node_id: last_id.clone(),
634 duration,
635 output_summary: format!(
636 "stream: {} chunks through {} filters",
637 executor.chunks_processed(),
638 node_ids.len()
639 ),
640 });
641
642 Ok(())
643}
644
645fn chunk_value(x: &Value, chunk_size: usize) -> Vec<Value> {
647 match x {
648 Value::Tensor { values, shape } if !values.is_empty() && chunk_size > 0 => {
649 let row_size = if shape.len() > 1 {
650 shape[1..].iter().product()
651 } else {
652 1
653 };
654 let n_rows = shape[0];
655 let mut chunks = Vec::new();
656 for start in (0..n_rows).step_by(chunk_size) {
657 let end = (start + chunk_size).min(n_rows);
658 let flat_start = start * row_size;
659 let flat_end = end * row_size;
660 let chunk_vals = values[flat_start..flat_end].to_vec();
661 let mut chunk_shape = shape.clone();
662 chunk_shape[0] = end - start;
663 chunks.push(Value::tensor(chunk_vals, chunk_shape));
664 }
665 chunks
666 }
667 _ => vec![x.clone()],
668 }
669}
670
671fn cache_as_arc(_cache: &dyn CacheStore) -> Option<Arc<dyn CacheStore>> {
674 None
679}
680
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
    use std::sync::{Condvar, Mutex};
    use std::time::Duration;

    /// Doubles every tensor element; other values pass through unchanged.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Doubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Adds `amount` to every tensor element; other values pass through.
    struct AdderFilter {
        amount: f64,
    }

    impl Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let added: Vec<f64> = values.iter().map(|v| v + self.amount).collect();
                    Ok(Value::tensor(added, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Adder".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Pass-through filter that returns only once `expected` instances
    /// sharing `gate` are inside `forward` at the same time.
    ///
    /// If that does not happen within `timeout` (e.g. because branches ran
    /// one after another), it fails instead of deadlocking.
    struct RendezvousFilter {
        id: String,
        gate: Arc<(Mutex<usize>, Condvar)>,
        expected: usize,
        timeout: Duration,
    }

    impl Filter for RendezvousFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Rendezvous", self.id.as_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            let (count, arrived) = &*self.gate;
            let mut entered = count.lock().unwrap();
            *entered += 1;
            arrived.notify_all();
            let (_entered, wait) = arrived
                .wait_timeout_while(entered, self.timeout, |n| *n < self.expected)
                .unwrap();
            if wait.timed_out() {
                return Err(SomaError::Execution {
                    node_id: self.id.clone(),
                    message: "parallel branches did not overlap".into(),
                });
            }
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: format!("Rendezvous_{}", self.id),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    fn setup() -> (Arc<EventBus>, MemoryCache) {
        (Arc::new(EventBus::new(64)), MemoryCache::default())
    }

    #[test]
    fn execute_single_node() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        ctx.graph_info
            .set_predecessors("doubler", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("doubler", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Execute {
            node_id: "doubler".into(),
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("doubler").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn execute_sequence_with_graph_info() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));

        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
        ctx.graph_info = graph_info;

        let mut filters = FilterLibrary::new();
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[22.0, 24.0]);
    }

    #[test]
    fn execute_cached_node() {
        let (bus, cache) = setup();
        let key = CacheKey::hash_data(b"cached_result");
        let cached_value = Value::tensor(vec![99.0], vec![1]);
        cache.put(&key, &cached_value).unwrap();

        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let plan = ExecutionPlan::Cached {
            node_id: "cached_node".into(),
            key,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
    }

    #[test]
    fn execute_emits_events() {
        let bus = Arc::new(EventBus::new(64));
        let cache = MemoryCache::default();
        let mut rx = bus.subscribe();

        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        execute(
            &ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        )
        .unwrap();

        let e1 = rx.try_recv().unwrap();
        assert!(matches!(e1, Event::NodeStarted { .. }));
        let e2 = rx.try_recv().unwrap();
        assert!(matches!(e2, Event::NodeCompleted { .. }));
    }

    #[test]
    fn execute_missing_filter_errors() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let result = execute(
            &ExecutionPlan::Execute {
                node_id: "nonexistent".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        );
        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn execute_empty_plan() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();
        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
    }

    #[test]
    fn execute_parallel_branches_merge_outputs() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);
        ctx.graph_info.set_predecessors("add", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
        assert_eq!(double_out, &[10.0]);
        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
        assert_eq!(add_out, &[105.0]);
    }

    #[test]
    fn parallel_branches_run_concurrently() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("slow_a", vec!["input".into()]);
        ctx.graph_info
            .set_predecessors("slow_b", vec!["input".into()]);

        // Both filters share one gate and each waits for the other to arrive.
        // This only succeeds if the branches genuinely overlap. Unlike a
        // wall-clock bound, it does not flake on a slow or loaded machine.
        let gate = Arc::new((Mutex::new(0), Condvar::new()));
        let mut filters = FilterLibrary::new();
        for (node_id, id) in [("slow_a", "a"), ("slow_b", "b")] {
            filters.register(
                node_id,
                Box::new(RendezvousFilter {
                    id: id.into(),
                    gate: Arc::clone(&gate),
                    expected: 2,
                    timeout: Duration::from_secs(5),
                }),
            );
        }

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "slow_a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "slow_b".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).expect("parallel branches should overlap");

        assert!(ctx.get("slow_a").is_some());
        assert!(ctx.get("slow_b").is_some());
    }

    #[test]
    fn resolve_input_single_predecessor() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
        ctx.graph_info.set_predecessors("B", vec!["A".into()]);

        let input = resolve_input("B", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);
    }

    #[test]
    fn resolve_input_multiple_predecessors() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
        ctx.graph_info
            .set_predecessors("C", vec!["A".into(), "B".into()]);

        let input = resolve_input("C", &ctx);
        let json = input.as_json().unwrap();
        assert!(json.get("A").is_some());
        assert!(json.get("B").is_some());
    }

    #[test]
    fn resolve_input_no_predecessors_fallback() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));

        let input = resolve_input("root", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[7.0]);
    }

    #[test]
    fn graph_info_from_linear() {
        let info = GraphInfo::for_linear(&["a", "b", "c"]);
        assert!(info.predecessors("a").is_empty());
        assert_eq!(info.predecessors("b"), &["a"]);
        assert_eq!(info.predecessors("c"), &["b"]);
    }

    #[test]
    fn execute_stream_chunks_input() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_stream");
        ctx.set(
            "__input__",
            Value::tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]),
        );
        ctx.graph_info
            .set_predecessors("double", vec!["__input__".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Stream {
            node_ids: vec!["double".into()],
            chunk_size: 2,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, shape) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
        assert_eq!(shape, &[6]);
    }

    #[test]
    fn execute_stream_chain() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_stream_chain");
        ctx.set(
            "__input__",
            Value::tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]),
        );
        ctx.graph_info
            .set_predecessors("double", vec!["__input__".into()]);
        ctx.graph_info
            .set_predecessors("add", vec!["double".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));

        let plan = ExecutionPlan::Stream {
            node_ids: vec!["double".into(), "add".into()],
            chunk_size: 2,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("add").unwrap();
        let (data, shape) = result.as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0, 18.0]);
        assert_eq!(shape, &[4]);
    }

    #[test]
    fn execute_stream_single_chunk() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_stream_single");
        ctx.set("__input__", Value::tensor(vec![5.0, 10.0], vec![2]));
        ctx.graph_info
            .set_predecessors("double", vec!["__input__".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Stream {
            node_ids: vec!["double".into()],
            chunk_size: 1000,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[10.0, 20.0]);
    }
}