somatize_runtime/runner/
local.rs1use super::Runner;
7use crate::EventBus;
8use crate::executor::{Context, Executable, GraphInfo};
9use crate::filter_library::FilterLibrary;
10
11use somatize_compiler::ExecutionPlan;
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::FilterKind;
16use somatize_core::util::timestamp_id;
17use somatize_core::value::Value;
18use std::collections::HashMap;
19use std::sync::Arc;
20
21pub struct LocalRunner;
23
24impl Runner for LocalRunner {
25 fn fit(
26 &self,
27 plan: &ExecutionPlan,
28 filters: &FilterLibrary,
29 cache: &dyn CacheStore,
30 event_bus: &Arc<EventBus>,
31 input: &Value,
32 y: Option<&Value>,
33 ) -> Result<(Value, HashMap<String, Value>)> {
34 let node_id_refs = plan.node_ids();
35 let node_ids: Vec<String> = node_id_refs.iter().map(|s| s.to_string()).collect();
36 let graph_info = GraphInfo::for_linear(&node_id_refs);
37 let run_id = timestamp_id("fit");
38 let mut outputs: HashMap<String, Value> = HashMap::new();
39 let mut trained_states: HashMap<String, Value> = HashMap::new();
40
41 if let Some(first) = node_ids.first() {
43 outputs.insert(format!("__input_{first}"), input.clone());
44 }
45
46 for node_id in &node_ids {
47 let filter = filters
48 .get(node_id)
49 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
50
51 let meta = filter.meta();
52
53 event_bus.emit(Event::NodeStarted {
54 run_id: run_id.clone(),
55 node_id: node_id.to_string(),
56 kind: meta.kind,
57 });
58
59 let preds = graph_info.predecessors(node_id);
61 let node_input = match preds.len() {
62 0 => input.clone(),
63 1 => outputs
64 .get(&preds[0])
65 .cloned()
66 .unwrap_or_else(|| input.clone()),
67 _ => {
68 let mut merged = serde_json::Map::new();
69 for pred_id in preds {
70 if let Some(val) = outputs.get(pred_id.as_str()) {
71 let json_val =
72 serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
73 merged.insert(pred_id.clone(), json_val);
74 }
75 }
76 Value::Json(serde_json::Value::Object(merged))
77 }
78 };
79
80 let start = std::time::Instant::now();
81
82 let state = if meta.kind == FilterKind::Trainable {
84 let data_hash =
85 CacheKey::hash_data(&serde_json::to_vec(&node_input).unwrap_or_default());
86 let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
87
88 let s = if let Some(cached) = cache.get(&state_key)? {
89 cached
90 } else {
91 let fitted = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
92 filter.fit(&node_input, y)
93 }))
94 .map_err(|panic| {
95 let msg = panic
96 .downcast_ref::<String>()
97 .map(|s| s.as_str())
98 .or_else(|| panic.downcast_ref::<&str>().copied())
99 .unwrap_or("unknown panic");
100 SomaError::Execution {
101 node_id: node_id.clone(),
102 message: format!("fit panicked: {msg}"),
103 }
104 })??;
105 let _ = cache.put(&state_key, &fitted);
106 fitted
107 };
108 trained_states.insert(node_id.clone(), s.clone());
109 s
110 } else {
111 filters.get_state(node_id).cloned().unwrap_or(Value::Empty)
112 };
113
114 let output = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
116 filter.forward(&node_input, &state)
117 }))
118 .map_err(|panic| {
119 let msg = panic
120 .downcast_ref::<String>()
121 .map(|s| s.as_str())
122 .or_else(|| panic.downcast_ref::<&str>().copied())
123 .unwrap_or("unknown panic");
124 SomaError::Execution {
125 node_id: node_id.clone(),
126 message: format!("forward panicked: {msg}"),
127 }
128 })??;
129
130 event_bus.emit(Event::NodeCompleted {
131 run_id: run_id.clone(),
132 node_id: node_id.to_string(),
133 duration: start.elapsed(),
134 output_summary: format!("{output}"),
135 });
136
137 outputs.insert(node_id.clone(), output);
138 }
139
140 let last_output = outputs.values().last().cloned().unwrap_or(Value::Empty);
141
142 for (id, state) in &trained_states {
145 outputs.insert(format!("__state_{id}"), state.clone());
146 }
147 Ok((last_output, outputs))
148 }
149
150 fn forward(
151 &self,
152 plan: &ExecutionPlan,
153 filters: &FilterLibrary,
154 cache: &dyn CacheStore,
155 event_bus: &Arc<EventBus>,
156 input: &Value,
157 ) -> Result<Value> {
158 let node_ids = plan.node_ids();
159 let graph_info = GraphInfo::for_linear(&node_ids);
160
161 let mut ctx =
162 Context::new(event_bus.clone(), timestamp_id("forward")).with_graph_info(graph_info);
163
164 if let Some(first) = node_ids.first() {
166 ctx.set(format!("__input_{first}"), input.clone());
167 }
168 ctx.set("__input__", input.clone());
169
170 plan.execute(&mut ctx, filters, cache)?;
171
172 ctx.execution_order
174 .last()
175 .and_then(|id| ctx.store.remove(id))
176 .and_then(|vv| vv.as_value().cloned())
177 .ok_or_else(|| SomaError::Other("no output produced".into()))
178 }
179}