Skip to main content

somatize_runtime/runner/
local.rs

1//! LocalRunner — executes plans locally using the Executor.
2//!
3//! This is the default runner. The worker's RemoteRunner delegates here
4//! after preparing the environment (deserializing filters, resolving input).
5
6use super::Runner;
7use crate::EventBus;
8use crate::executor::{Context, Executable, GraphInfo};
9use crate::filter_library::FilterLibrary;
10
11use somatize_compiler::ExecutionPlan;
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::FilterKind;
16use somatize_core::util::timestamp_id;
17use somatize_core::value::Value;
18use std::collections::HashMap;
19use std::sync::Arc;
20
21/// Executes plans locally — same logic for local and remote execution.
22pub struct LocalRunner;
23
24impl LocalRunner {
25    /// Fit a Sequence plan, handling Composite steps as blocks.
26    fn fit_sequence(
27        &self,
28        steps: &[ExecutionPlan],
29        filters: &FilterLibrary,
30        cache: &dyn CacheStore,
31        event_bus: &Arc<EventBus>,
32        input: &Value,
33        y: Option<&Value>,
34    ) -> Result<(Value, HashMap<String, Value>)> {
35        let mut current_input = input.clone();
36        let mut all_outputs = HashMap::new();
37
38        for step in steps {
39            let handled = if let ExecutionPlan::Composite { node_ids } = step
40                && let Some(filter) = filters.get(&node_ids[0])
41                && let Some(result) = filter.composite_fit(node_ids, &current_input, y)
42            {
43                let (output, states) = result?;
44                for (id, state) in &states {
45                    all_outputs.insert(format!("__state_{id}"), state.clone());
46                }
47                if let Some(last_id) = node_ids.last() {
48                    all_outputs.insert(last_id.clone(), output.clone());
49                }
50                current_input = output;
51                true
52            } else {
53                false
54            };
55
56            if !handled {
57                let sub_result = <Self as Runner>::fit(
58                    self,
59                    step,
60                    filters,
61                    cache,
62                    event_bus,
63                    &current_input,
64                    y,
65                )?;
66                current_input = sub_result.0;
67                all_outputs.extend(sub_result.1);
68            }
69        }
70
71        Ok((current_input, all_outputs))
72    }
73}
74
75impl Runner for LocalRunner {
76    fn fit(
77        &self,
78        plan: &ExecutionPlan,
79        filters: &FilterLibrary,
80        cache: &dyn CacheStore,
81        event_bus: &Arc<EventBus>,
82        input: &Value,
83        y: Option<&Value>,
84    ) -> Result<(Value, HashMap<String, Value>)> {
85        // Handle Composite plan: delegate to composite_fit on the first filter
86        if let ExecutionPlan::Composite { node_ids } = plan
87            && let Some(filter) = filters.get(&node_ids[0])
88            && let Some(result) = filter.composite_fit(node_ids, input, y)
89        {
90            return result;
91            // Fallback: treat as sequential if composite_fit not supported
92        }
93
94        // Handle Sequence that may contain Composite steps
95        if let ExecutionPlan::Sequence(steps) = plan {
96            return self.fit_sequence(steps, filters, cache, event_bus, input, y);
97        }
98
99        // Single node or other plan types: sequential fit
100        let node_id_refs = plan.node_ids();
101        let node_ids: Vec<String> = node_id_refs.iter().map(|s| s.to_string()).collect();
102        let graph_info = GraphInfo::for_linear(&node_id_refs);
103        let run_id = timestamp_id("fit");
104        let mut outputs: HashMap<String, Value> = HashMap::new();
105        let mut trained_states: HashMap<String, Value> = HashMap::new();
106
107        if let Some(first) = node_ids.first() {
108            outputs.insert(format!("__input_{first}"), input.clone());
109        }
110
111        for node_id in &node_ids {
112            let filter = filters
113                .get(node_id)
114                .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
115
116            let meta = filter.meta();
117
118            event_bus.emit(Event::NodeStarted {
119                run_id: run_id.clone(),
120                node_id: node_id.to_string(),
121                kind: meta.kind,
122            });
123
124            // Resolve input from predecessors
125            let preds = graph_info.predecessors(node_id);
126            let node_input = match preds.len() {
127                0 => input.clone(),
128                1 => outputs
129                    .get(&preds[0])
130                    .cloned()
131                    .unwrap_or_else(|| input.clone()),
132                _ => {
133                    let mut merged = serde_json::Map::new();
134                    for pred_id in preds {
135                        if let Some(val) = outputs.get(pred_id.as_str()) {
136                            let json_val =
137                                serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
138                            merged.insert(pred_id.clone(), json_val);
139                        }
140                    }
141                    Value::Json(serde_json::Value::Object(merged))
142                }
143            };
144
145            let start = std::time::Instant::now();
146
147            // Fit trainable filters
148            let state = if meta.kind == FilterKind::Trainable {
149                let data_hash =
150                    CacheKey::hash_data(&serde_json::to_vec(&node_input).unwrap_or_default());
151                let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
152
153                let s = if let Some(cached) = cache.get(&state_key)? {
154                    cached
155                } else {
156                    let fitted = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
157                        filter.fit(&node_input, y)
158                    }))
159                    .map_err(|panic| {
160                        let msg = panic
161                            .downcast_ref::<String>()
162                            .map(|s| s.as_str())
163                            .or_else(|| panic.downcast_ref::<&str>().copied())
164                            .unwrap_or("unknown panic");
165                        SomaError::Execution {
166                            node_id: node_id.clone(),
167                            message: format!("fit panicked: {msg}"),
168                        }
169                    })??;
170                    let _ = cache.put(&state_key, &fitted);
171                    fitted
172                };
173                trained_states.insert(node_id.clone(), s.clone());
174                s
175            } else {
176                filters.get_state(node_id).cloned().unwrap_or(Value::Empty)
177            };
178
179            // Forward with state (catch panics)
180            let output = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
181                filter.forward(&node_input, &state)
182            }))
183            .map_err(|panic| {
184                let msg = panic
185                    .downcast_ref::<String>()
186                    .map(|s| s.as_str())
187                    .or_else(|| panic.downcast_ref::<&str>().copied())
188                    .unwrap_or("unknown panic");
189                SomaError::Execution {
190                    node_id: node_id.clone(),
191                    message: format!("forward panicked: {msg}"),
192                }
193            })??;
194
195            event_bus.emit(Event::NodeCompleted {
196                run_id: run_id.clone(),
197                node_id: node_id.to_string(),
198                duration: start.elapsed(),
199                output_summary: format!("{output}"),
200            });
201
202            outputs.insert(node_id.clone(), output);
203        }
204
205        let last_output = outputs.values().last().cloned().unwrap_or(Value::Empty);
206
207        // Forward outputs keyed by node_id (for GraphSession inspection).
208        // Trained states added with __state_ prefix (for Worker to extract).
209        for (id, state) in &trained_states {
210            outputs.insert(format!("__state_{id}"), state.clone());
211        }
212        Ok((last_output, outputs))
213    }
214
215    fn forward(
216        &self,
217        plan: &ExecutionPlan,
218        filters: &FilterLibrary,
219        cache: &dyn CacheStore,
220        event_bus: &Arc<EventBus>,
221        input: &Value,
222    ) -> Result<Value> {
223        let node_ids = plan.node_ids();
224        let graph_info = GraphInfo::for_linear(&node_ids);
225
226        let mut ctx =
227            Context::new(event_bus.clone(), timestamp_id("forward")).with_graph_info(graph_info);
228
229        // Set input for root nodes
230        if let Some(first) = node_ids.first() {
231            ctx.set(format!("__input_{first}"), input.clone());
232        }
233        ctx.set("__input__", input.clone());
234
235        plan.execute(&mut ctx, filters, cache)?;
236
237        // Return last executed node's output
238        ctx.execution_order
239            .last()
240            .and_then(|id| ctx.store.remove(id))
241            .and_then(|vv| vv.as_value().cloned())
242            .ok_or_else(|| SomaError::Other("no output produced".into()))
243    }
244}