1use super::Runner;
7use crate::EventBus;
8use crate::executor::{Context, Executable, GraphInfo};
9use crate::filter_library::FilterLibrary;
10
11use somatize_compiler::ExecutionPlan;
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::{Filter, FilterKind};
16use somatize_core::util::timestamp_id;
17use somatize_core::value::Value;
18use std::borrow::Cow;
19use std::collections::HashMap;
20use std::sync::Arc;
21
22fn collect_peers(
26 filters: &FilterLibrary,
27 node_ids: &[String],
28) -> Option<Vec<(String, Arc<dyn Filter>)>> {
29 node_ids
30 .iter()
31 .map(|id| filters.get(id).map(|f| (id.clone(), f)))
32 .collect()
33}
34
35pub struct LocalRunner;
37
38impl LocalRunner {
39 fn fit_sequence(
41 &self,
42 steps: &[ExecutionPlan],
43 filters: &FilterLibrary,
44 cache: &dyn CacheStore,
45 event_bus: &Arc<EventBus>,
46 input: &Value,
47 y: Option<&Value>,
48 ) -> Result<(Value, HashMap<String, Value>)> {
49 let mut current_input = input.clone();
50 let mut all_outputs = HashMap::new();
51
52 for step in steps {
53 let handled = if let ExecutionPlan::Composite { node_ids } = step
54 && let Some(peers) = collect_peers(filters, node_ids)
55 && let Some(filter) = filters.get(&node_ids[0])
56 && let Some(result) = filter.composite_fit(&peers, ¤t_input, y)
57 {
58 let (output, states) = result?;
59 for (id, state) in &states {
60 all_outputs.insert(format!("__state_{id}"), state.clone());
61 }
62 if let Some(last_id) = node_ids.last() {
63 all_outputs.insert(last_id.clone(), output.clone());
64 }
65 current_input = output;
66 true
67 } else {
68 false
69 };
70
71 if !handled {
72 let sub_result = <Self as Runner>::fit(
73 self,
74 step,
75 filters,
76 cache,
77 event_bus,
78 ¤t_input,
79 y,
80 )?;
81 current_input = sub_result.0;
82 all_outputs.extend(sub_result.1);
83 }
84 }
85
86 Ok((current_input, all_outputs))
87 }
88}
89
90impl Runner for LocalRunner {
91 fn fit(
92 &self,
93 plan: &ExecutionPlan,
94 filters: &FilterLibrary,
95 cache: &dyn CacheStore,
96 event_bus: &Arc<EventBus>,
97 input: &Value,
98 y: Option<&Value>,
99 ) -> Result<(Value, HashMap<String, Value>)> {
100 if let ExecutionPlan::Composite { node_ids } = plan
102 && let Some(peers) = collect_peers(filters, node_ids)
103 && let Some(filter) = filters.get(&node_ids[0])
104 && let Some(result) = filter.composite_fit(&peers, input, y)
105 {
106 return result;
107 }
109
110 if let ExecutionPlan::Sequence(steps) = plan {
112 return self.fit_sequence(steps, filters, cache, event_bus, input, y);
113 }
114
115 let node_id_refs = plan.node_ids();
117 let node_ids: Vec<String> = node_id_refs.iter().map(|s| s.to_string()).collect();
118 let graph_info = GraphInfo::for_linear(&node_id_refs);
119 let run_id = timestamp_id("fit");
120 let mut outputs: HashMap<String, Value> = HashMap::new();
121 let mut trained_states: HashMap<String, Value> = HashMap::new();
122
123 if let Some(first) = node_ids.first() {
124 outputs.insert(format!("__input_{first}"), input.clone());
125 }
126
127 for node_id in &node_ids {
128 let filter = filters
129 .get(node_id)
130 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
131
132 let meta = filter.meta();
133
134 event_bus.emit(Event::NodeStarted {
135 run_id: run_id.clone(),
136 node_id: node_id.to_string(),
137 kind: meta.kind,
138 });
139
140 let preds = graph_info.predecessors(node_id);
142 let node_input = match preds.len() {
143 0 => input.clone(),
144 1 => outputs
145 .get(&preds[0])
146 .cloned()
147 .unwrap_or_else(|| input.clone()),
148 _ => {
149 let mut merged = serde_json::Map::new();
150 for pred_id in preds {
151 if let Some(val) = outputs.get(pred_id.as_str()) {
152 let json_val =
153 serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
154 merged.insert(pred_id.clone(), json_val);
155 }
156 }
157 Value::json(serde_json::Value::Object(merged))
158 }
159 };
160
161 let start = std::time::Instant::now();
162
163 let library_state: Option<Arc<Value>> = if meta.kind == FilterKind::Trainable {
168 None
169 } else {
170 filters.get_state(node_id)
171 };
172 let empty_state = Value::Empty;
173 let state: Cow<Value> = if meta.kind == FilterKind::Trainable {
174 let data_hash =
175 CacheKey::hash_data(&serde_json::to_vec(&node_input).unwrap_or_default());
176 let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
177
178 let s = if let Some(cached) = cache.get(&state_key)? {
179 cached
180 } else {
181 let fitted = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
182 filter.fit(&node_input, y)
183 }))
184 .map_err(|panic| {
185 let msg = panic
186 .downcast_ref::<String>()
187 .map(|s| s.as_str())
188 .or_else(|| panic.downcast_ref::<&str>().copied())
189 .unwrap_or("unknown panic");
190 SomaError::Execution {
191 node_id: node_id.clone(),
192 message: format!("fit panicked: {msg}"),
193 }
194 })??;
195 let _ = cache.put(&state_key, &fitted);
196 fitted
197 };
198 trained_states.insert(node_id.clone(), s.clone());
199 Cow::Owned(s)
200 } else {
201 match library_state.as_deref() {
202 Some(v) => Cow::Borrowed(v),
203 None => Cow::Borrowed(&empty_state),
204 }
205 };
206
207 let output = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
209 filter.forward(&node_input, &state)
210 }))
211 .map_err(|panic| {
212 let msg = panic
213 .downcast_ref::<String>()
214 .map(|s| s.as_str())
215 .or_else(|| panic.downcast_ref::<&str>().copied())
216 .unwrap_or("unknown panic");
217 SomaError::Execution {
218 node_id: node_id.clone(),
219 message: format!("forward panicked: {msg}"),
220 }
221 })??;
222
223 event_bus.emit(Event::NodeCompleted {
224 run_id: run_id.clone(),
225 node_id: node_id.to_string(),
226 duration: start.elapsed(),
227 output_summary: format!("{output}"),
228 });
229
230 outputs.insert(node_id.clone(), output);
231 }
232
233 let last_output = outputs.values().last().cloned().unwrap_or(Value::Empty);
234
235 for (id, state) in &trained_states {
238 outputs.insert(format!("__state_{id}"), state.clone());
239 }
240 Ok((last_output, outputs))
241 }
242
243 fn forward(
244 &self,
245 plan: &ExecutionPlan,
246 filters: &FilterLibrary,
247 cache: &dyn CacheStore,
248 event_bus: &Arc<EventBus>,
249 input: &Value,
250 ) -> Result<Value> {
251 let node_ids = plan.node_ids();
252 let graph_info = GraphInfo::for_linear(&node_ids);
253
254 let mut ctx =
255 Context::new(event_bus.clone(), timestamp_id("forward")).with_graph_info(graph_info);
256
257 if let Some(first) = node_ids.first() {
259 ctx.set(format!("__input_{first}"), input.clone());
260 }
261 ctx.set("__input__", input.clone());
262
263 plan.execute(&mut ctx, filters, cache)?;
264
265 ctx.execution_order
267 .last()
268 .and_then(|id| ctx.store.remove(id))
269 .and_then(|vv| vv.as_value().cloned())
270 .ok_or_else(|| SomaError::Other("no output produced".into()))
271 }
272}