somatize_runtime/runner/
local.rs1use super::Runner;
7use crate::EventBus;
8use crate::executor::{Context, Executable, GraphInfo};
9use crate::filter_library::FilterLibrary;
10
11use somatize_compiler::ExecutionPlan;
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::FilterKind;
16use somatize_core::util::timestamp_id;
17use somatize_core::value::Value;
18use std::collections::HashMap;
19use std::sync::Arc;
20
/// Runner that executes plans locally.
///
/// [`Runner::fit`] invokes each filter's `fit` / `forward` directly in the calling
/// thread. [`Runner::forward`] delegates to `ExecutionPlan::execute`.
///
/// The type is a stateless marker, so it is cheap to copy and to construct via
/// [`Default`].
#[derive(Debug, Default, Clone, Copy)]
pub struct LocalRunner;

24impl LocalRunner {
25 fn fit_sequence(
27 &self,
28 steps: &[ExecutionPlan],
29 filters: &FilterLibrary,
30 cache: &dyn CacheStore,
31 event_bus: &Arc<EventBus>,
32 input: &Value,
33 y: Option<&Value>,
34 ) -> Result<(Value, HashMap<String, Value>)> {
35 let mut current_input = input.clone();
36 let mut all_outputs = HashMap::new();
37
38 for step in steps {
39 let handled = if let ExecutionPlan::Composite { node_ids } = step
40 && let Some(filter) = filters.get(&node_ids[0])
41 && let Some(result) = filter.composite_fit(node_ids, ¤t_input, y)
42 {
43 let (output, states) = result?;
44 for (id, state) in &states {
45 all_outputs.insert(format!("__state_{id}"), state.clone());
46 }
47 if let Some(last_id) = node_ids.last() {
48 all_outputs.insert(last_id.clone(), output.clone());
49 }
50 current_input = output;
51 true
52 } else {
53 false
54 };
55
56 if !handled {
57 let sub_result = <Self as Runner>::fit(
58 self,
59 step,
60 filters,
61 cache,
62 event_bus,
63 ¤t_input,
64 y,
65 )?;
66 current_input = sub_result.0;
67 all_outputs.extend(sub_result.1);
68 }
69 }
70
71 Ok((current_input, all_outputs))
72 }
73}
74
75impl Runner for LocalRunner {
76 fn fit(
77 &self,
78 plan: &ExecutionPlan,
79 filters: &FilterLibrary,
80 cache: &dyn CacheStore,
81 event_bus: &Arc<EventBus>,
82 input: &Value,
83 y: Option<&Value>,
84 ) -> Result<(Value, HashMap<String, Value>)> {
85 if let ExecutionPlan::Composite { node_ids } = plan
87 && let Some(filter) = filters.get(&node_ids[0])
88 && let Some(result) = filter.composite_fit(node_ids, input, y)
89 {
90 return result;
91 }
93
94 if let ExecutionPlan::Sequence(steps) = plan {
96 return self.fit_sequence(steps, filters, cache, event_bus, input, y);
97 }
98
99 let node_id_refs = plan.node_ids();
101 let node_ids: Vec<String> = node_id_refs.iter().map(|s| s.to_string()).collect();
102 let graph_info = GraphInfo::for_linear(&node_id_refs);
103 let run_id = timestamp_id("fit");
104 let mut outputs: HashMap<String, Value> = HashMap::new();
105 let mut trained_states: HashMap<String, Value> = HashMap::new();
106
107 if let Some(first) = node_ids.first() {
108 outputs.insert(format!("__input_{first}"), input.clone());
109 }
110
111 for node_id in &node_ids {
112 let filter = filters
113 .get(node_id)
114 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
115
116 let meta = filter.meta();
117
118 event_bus.emit(Event::NodeStarted {
119 run_id: run_id.clone(),
120 node_id: node_id.to_string(),
121 kind: meta.kind,
122 });
123
124 let preds = graph_info.predecessors(node_id);
126 let node_input = match preds.len() {
127 0 => input.clone(),
128 1 => outputs
129 .get(&preds[0])
130 .cloned()
131 .unwrap_or_else(|| input.clone()),
132 _ => {
133 let mut merged = serde_json::Map::new();
134 for pred_id in preds {
135 if let Some(val) = outputs.get(pred_id.as_str()) {
136 let json_val =
137 serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
138 merged.insert(pred_id.clone(), json_val);
139 }
140 }
141 Value::Json(serde_json::Value::Object(merged))
142 }
143 };
144
145 let start = std::time::Instant::now();
146
147 let state = if meta.kind == FilterKind::Trainable {
149 let data_hash =
150 CacheKey::hash_data(&serde_json::to_vec(&node_input).unwrap_or_default());
151 let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
152
153 let s = if let Some(cached) = cache.get(&state_key)? {
154 cached
155 } else {
156 let fitted = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
157 filter.fit(&node_input, y)
158 }))
159 .map_err(|panic| {
160 let msg = panic
161 .downcast_ref::<String>()
162 .map(|s| s.as_str())
163 .or_else(|| panic.downcast_ref::<&str>().copied())
164 .unwrap_or("unknown panic");
165 SomaError::Execution {
166 node_id: node_id.clone(),
167 message: format!("fit panicked: {msg}"),
168 }
169 })??;
170 let _ = cache.put(&state_key, &fitted);
171 fitted
172 };
173 trained_states.insert(node_id.clone(), s.clone());
174 s
175 } else {
176 filters.get_state(node_id).cloned().unwrap_or(Value::Empty)
177 };
178
179 let output = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
181 filter.forward(&node_input, &state)
182 }))
183 .map_err(|panic| {
184 let msg = panic
185 .downcast_ref::<String>()
186 .map(|s| s.as_str())
187 .or_else(|| panic.downcast_ref::<&str>().copied())
188 .unwrap_or("unknown panic");
189 SomaError::Execution {
190 node_id: node_id.clone(),
191 message: format!("forward panicked: {msg}"),
192 }
193 })??;
194
195 event_bus.emit(Event::NodeCompleted {
196 run_id: run_id.clone(),
197 node_id: node_id.to_string(),
198 duration: start.elapsed(),
199 output_summary: format!("{output}"),
200 });
201
202 outputs.insert(node_id.clone(), output);
203 }
204
205 let last_output = outputs.values().last().cloned().unwrap_or(Value::Empty);
206
207 for (id, state) in &trained_states {
210 outputs.insert(format!("__state_{id}"), state.clone());
211 }
212 Ok((last_output, outputs))
213 }
214
215 fn forward(
216 &self,
217 plan: &ExecutionPlan,
218 filters: &FilterLibrary,
219 cache: &dyn CacheStore,
220 event_bus: &Arc<EventBus>,
221 input: &Value,
222 ) -> Result<Value> {
223 let node_ids = plan.node_ids();
224 let graph_info = GraphInfo::for_linear(&node_ids);
225
226 let mut ctx =
227 Context::new(event_bus.clone(), timestamp_id("forward")).with_graph_info(graph_info);
228
229 if let Some(first) = node_ids.first() {
231 ctx.set(format!("__input_{first}"), input.clone());
232 }
233 ctx.set("__input__", input.clone());
234
235 plan.execute(&mut ctx, filters, cache)?;
236
237 ctx.execution_order
239 .last()
240 .and_then(|id| ctx.store.remove(id))
241 .and_then(|vv| vv.as_value().cloned())
242 .ok_or_else(|| SomaError::Other("no output produced".into()))
243 }
244}