somatize_runtime/runner/
local.rs1use super::Runner;
7use crate::EventBus;
8use crate::executor::{Context, Executable, GraphInfo};
9use crate::filter_library::FilterLibrary;
10
11use somatize_compiler::ExecutionPlan;
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::FilterKind;
16use somatize_core::util::timestamp_id;
17use somatize_core::value::Value;
18use std::borrow::Cow;
19use std::collections::HashMap;
20use std::sync::Arc;
21
/// A stateless [`Runner`] that fits and executes an [`ExecutionPlan`] by
/// calling the filters of a [`FilterLibrary`] directly from its `fit` and
/// `forward` methods.
pub struct LocalRunner;
24
25impl LocalRunner {
26 fn fit_sequence(
28 &self,
29 steps: &[ExecutionPlan],
30 filters: &FilterLibrary,
31 cache: &dyn CacheStore,
32 event_bus: &Arc<EventBus>,
33 input: &Value,
34 y: Option<&Value>,
35 ) -> Result<(Value, HashMap<String, Value>)> {
36 let mut current_input = input.clone();
37 let mut all_outputs = HashMap::new();
38
39 for step in steps {
40 let handled = if let ExecutionPlan::Composite { node_ids } = step
41 && let Some(filter) = filters.get(&node_ids[0])
42 && let Some(result) = filter.composite_fit(node_ids, ¤t_input, y)
43 {
44 let (output, states) = result?;
45 for (id, state) in &states {
46 all_outputs.insert(format!("__state_{id}"), state.clone());
47 }
48 if let Some(last_id) = node_ids.last() {
49 all_outputs.insert(last_id.clone(), output.clone());
50 }
51 current_input = output;
52 true
53 } else {
54 false
55 };
56
57 if !handled {
58 let sub_result = <Self as Runner>::fit(
59 self,
60 step,
61 filters,
62 cache,
63 event_bus,
64 ¤t_input,
65 y,
66 )?;
67 current_input = sub_result.0;
68 all_outputs.extend(sub_result.1);
69 }
70 }
71
72 Ok((current_input, all_outputs))
73 }
74}
75
76impl Runner for LocalRunner {
77 fn fit(
78 &self,
79 plan: &ExecutionPlan,
80 filters: &FilterLibrary,
81 cache: &dyn CacheStore,
82 event_bus: &Arc<EventBus>,
83 input: &Value,
84 y: Option<&Value>,
85 ) -> Result<(Value, HashMap<String, Value>)> {
86 if let ExecutionPlan::Composite { node_ids } = plan
88 && let Some(filter) = filters.get(&node_ids[0])
89 && let Some(result) = filter.composite_fit(node_ids, input, y)
90 {
91 return result;
92 }
94
95 if let ExecutionPlan::Sequence(steps) = plan {
97 return self.fit_sequence(steps, filters, cache, event_bus, input, y);
98 }
99
100 let node_id_refs = plan.node_ids();
102 let node_ids: Vec<String> = node_id_refs.iter().map(|s| s.to_string()).collect();
103 let graph_info = GraphInfo::for_linear(&node_id_refs);
104 let run_id = timestamp_id("fit");
105 let mut outputs: HashMap<String, Value> = HashMap::new();
106 let mut trained_states: HashMap<String, Value> = HashMap::new();
107
108 if let Some(first) = node_ids.first() {
109 outputs.insert(format!("__input_{first}"), input.clone());
110 }
111
112 for node_id in &node_ids {
113 let filter = filters
114 .get(node_id)
115 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
116
117 let meta = filter.meta();
118
119 event_bus.emit(Event::NodeStarted {
120 run_id: run_id.clone(),
121 node_id: node_id.to_string(),
122 kind: meta.kind,
123 });
124
125 let preds = graph_info.predecessors(node_id);
127 let node_input = match preds.len() {
128 0 => input.clone(),
129 1 => outputs
130 .get(&preds[0])
131 .cloned()
132 .unwrap_or_else(|| input.clone()),
133 _ => {
134 let mut merged = serde_json::Map::new();
135 for pred_id in preds {
136 if let Some(val) = outputs.get(pred_id.as_str()) {
137 let json_val =
138 serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
139 merged.insert(pred_id.clone(), json_val);
140 }
141 }
142 Value::json(serde_json::Value::Object(merged))
143 }
144 };
145
146 let start = std::time::Instant::now();
147
148 let library_state: Option<Arc<Value>> = if meta.kind == FilterKind::Trainable {
153 None
154 } else {
155 filters.get_state(node_id)
156 };
157 let empty_state = Value::Empty;
158 let state: Cow<Value> = if meta.kind == FilterKind::Trainable {
159 let data_hash =
160 CacheKey::hash_data(&serde_json::to_vec(&node_input).unwrap_or_default());
161 let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
162
163 let s = if let Some(cached) = cache.get(&state_key)? {
164 cached
165 } else {
166 let fitted = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
167 filter.fit(&node_input, y)
168 }))
169 .map_err(|panic| {
170 let msg = panic
171 .downcast_ref::<String>()
172 .map(|s| s.as_str())
173 .or_else(|| panic.downcast_ref::<&str>().copied())
174 .unwrap_or("unknown panic");
175 SomaError::Execution {
176 node_id: node_id.clone(),
177 message: format!("fit panicked: {msg}"),
178 }
179 })??;
180 let _ = cache.put(&state_key, &fitted);
181 fitted
182 };
183 trained_states.insert(node_id.clone(), s.clone());
184 Cow::Owned(s)
185 } else {
186 match library_state.as_deref() {
187 Some(v) => Cow::Borrowed(v),
188 None => Cow::Borrowed(&empty_state),
189 }
190 };
191
192 let output = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
194 filter.forward(&node_input, &state)
195 }))
196 .map_err(|panic| {
197 let msg = panic
198 .downcast_ref::<String>()
199 .map(|s| s.as_str())
200 .or_else(|| panic.downcast_ref::<&str>().copied())
201 .unwrap_or("unknown panic");
202 SomaError::Execution {
203 node_id: node_id.clone(),
204 message: format!("forward panicked: {msg}"),
205 }
206 })??;
207
208 event_bus.emit(Event::NodeCompleted {
209 run_id: run_id.clone(),
210 node_id: node_id.to_string(),
211 duration: start.elapsed(),
212 output_summary: format!("{output}"),
213 });
214
215 outputs.insert(node_id.clone(), output);
216 }
217
218 let last_output = outputs.values().last().cloned().unwrap_or(Value::Empty);
219
220 for (id, state) in &trained_states {
223 outputs.insert(format!("__state_{id}"), state.clone());
224 }
225 Ok((last_output, outputs))
226 }
227
228 fn forward(
229 &self,
230 plan: &ExecutionPlan,
231 filters: &FilterLibrary,
232 cache: &dyn CacheStore,
233 event_bus: &Arc<EventBus>,
234 input: &Value,
235 ) -> Result<Value> {
236 let node_ids = plan.node_ids();
237 let graph_info = GraphInfo::for_linear(&node_ids);
238
239 let mut ctx =
240 Context::new(event_bus.clone(), timestamp_id("forward")).with_graph_info(graph_info);
241
242 if let Some(first) = node_ids.first() {
244 ctx.set(format!("__input_{first}"), input.clone());
245 }
246 ctx.set("__input__", input.clone());
247
248 plan.execute(&mut ctx, filters, cache)?;
249
250 ctx.execution_order
252 .last()
253 .and_then(|id| ctx.store.remove(id))
254 .and_then(|vv| vv.as_value().cloned())
255 .ok_or_else(|| SomaError::Other("no output produced".into()))
256 }
257}