// somatize_runtime/runner/local.rs

//! LocalRunner — executes plans locally using the Executor.
//!
//! This is the default runner. The worker's RemoteRunner delegates here
//! after preparing the environment (deserializing filters, resolving input).

use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use somatize_compiler::ExecutionPlan;
use somatize_core::cache::{CacheKey, CacheStore};
use somatize_core::error::{Result, SomaError};
use somatize_core::event::Event;
use somatize_core::filter::FilterKind;
use somatize_core::util::timestamp_id;
use somatize_core::value::Value;

use super::Runner;
use crate::EventBus;
use crate::executor::{Context, Executable, GraphInfo};
use crate::filter_library::FilterLibrary;

/// Runner that executes plans in the current process.
///
/// Local and remote execution share this logic: the worker's remote runner
/// delegates here once its environment is prepared. This is a unit struct
/// and holds no state of its own.
pub struct LocalRunner;

25impl LocalRunner {
26    /// Fit a Sequence plan, handling Composite steps as blocks.
27    fn fit_sequence(
28        &self,
29        steps: &[ExecutionPlan],
30        filters: &FilterLibrary,
31        cache: &dyn CacheStore,
32        event_bus: &Arc<EventBus>,
33        input: &Value,
34        y: Option<&Value>,
35    ) -> Result<(Value, HashMap<String, Value>)> {
36        let mut current_input = input.clone();
37        let mut all_outputs = HashMap::new();
38
39        for step in steps {
40            let handled = if let ExecutionPlan::Composite { node_ids } = step
41                && let Some(filter) = filters.get(&node_ids[0])
42                && let Some(result) = filter.composite_fit(node_ids, &current_input, y)
43            {
44                let (output, states) = result?;
45                for (id, state) in &states {
46                    all_outputs.insert(format!("__state_{id}"), state.clone());
47                }
48                if let Some(last_id) = node_ids.last() {
49                    all_outputs.insert(last_id.clone(), output.clone());
50                }
51                current_input = output;
52                true
53            } else {
54                false
55            };
56
57            if !handled {
58                let sub_result = <Self as Runner>::fit(
59                    self,
60                    step,
61                    filters,
62                    cache,
63                    event_bus,
64                    &current_input,
65                    y,
66                )?;
67                current_input = sub_result.0;
68                all_outputs.extend(sub_result.1);
69            }
70        }
71
72        Ok((current_input, all_outputs))
73    }
74}
75
76impl Runner for LocalRunner {
77    fn fit(
78        &self,
79        plan: &ExecutionPlan,
80        filters: &FilterLibrary,
81        cache: &dyn CacheStore,
82        event_bus: &Arc<EventBus>,
83        input: &Value,
84        y: Option<&Value>,
85    ) -> Result<(Value, HashMap<String, Value>)> {
86        // Handle Composite plan: delegate to composite_fit on the first filter
87        if let ExecutionPlan::Composite { node_ids } = plan
88            && let Some(filter) = filters.get(&node_ids[0])
89            && let Some(result) = filter.composite_fit(node_ids, input, y)
90        {
91            return result;
92            // Fallback: treat as sequential if composite_fit not supported
93        }
94
95        // Handle Sequence that may contain Composite steps
96        if let ExecutionPlan::Sequence(steps) = plan {
97            return self.fit_sequence(steps, filters, cache, event_bus, input, y);
98        }
99
100        // Single node or other plan types: sequential fit
101        let node_id_refs = plan.node_ids();
102        let node_ids: Vec<String> = node_id_refs.iter().map(|s| s.to_string()).collect();
103        let graph_info = GraphInfo::for_linear(&node_id_refs);
104        let run_id = timestamp_id("fit");
105        let mut outputs: HashMap<String, Value> = HashMap::new();
106        let mut trained_states: HashMap<String, Value> = HashMap::new();
107
108        if let Some(first) = node_ids.first() {
109            outputs.insert(format!("__input_{first}"), input.clone());
110        }
111
112        for node_id in &node_ids {
113            let filter = filters
114                .get(node_id)
115                .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
116
117            let meta = filter.meta();
118
119            event_bus.emit(Event::NodeStarted {
120                run_id: run_id.clone(),
121                node_id: node_id.to_string(),
122                kind: meta.kind,
123            });
124
125            // Resolve input from predecessors
126            let preds = graph_info.predecessors(node_id);
127            let node_input = match preds.len() {
128                0 => input.clone(),
129                1 => outputs
130                    .get(&preds[0])
131                    .cloned()
132                    .unwrap_or_else(|| input.clone()),
133                _ => {
134                    let mut merged = serde_json::Map::new();
135                    for pred_id in preds {
136                        if let Some(val) = outputs.get(pred_id.as_str()) {
137                            let json_val =
138                                serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
139                            merged.insert(pred_id.clone(), json_val);
140                        }
141                    }
142                    Value::json(serde_json::Value::Object(merged))
143                }
144            };
145
146            let start = std::time::Instant::now();
147
148            // Fit trainable filters. Use Cow so the non-trainable branch
149            // can borrow the existing state (via the StateStore Arc)
150            // without cloning potentially huge tensors on every forward
151            // call.
152            let library_state: Option<Arc<Value>> = if meta.kind == FilterKind::Trainable {
153                None
154            } else {
155                filters.get_state(node_id)
156            };
157            let empty_state = Value::Empty;
158            let state: Cow<Value> = if meta.kind == FilterKind::Trainable {
159                let data_hash =
160                    CacheKey::hash_data(&serde_json::to_vec(&node_input).unwrap_or_default());
161                let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
162
163                let s = if let Some(cached) = cache.get(&state_key)? {
164                    cached
165                } else {
166                    let fitted = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
167                        filter.fit(&node_input, y)
168                    }))
169                    .map_err(|panic| {
170                        let msg = panic
171                            .downcast_ref::<String>()
172                            .map(|s| s.as_str())
173                            .or_else(|| panic.downcast_ref::<&str>().copied())
174                            .unwrap_or("unknown panic");
175                        SomaError::Execution {
176                            node_id: node_id.clone(),
177                            message: format!("fit panicked: {msg}"),
178                        }
179                    })??;
180                    let _ = cache.put(&state_key, &fitted);
181                    fitted
182                };
183                trained_states.insert(node_id.clone(), s.clone());
184                Cow::Owned(s)
185            } else {
186                match library_state.as_deref() {
187                    Some(v) => Cow::Borrowed(v),
188                    None => Cow::Borrowed(&empty_state),
189                }
190            };
191
192            // Forward with state (catch panics)
193            let output = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
194                filter.forward(&node_input, &state)
195            }))
196            .map_err(|panic| {
197                let msg = panic
198                    .downcast_ref::<String>()
199                    .map(|s| s.as_str())
200                    .or_else(|| panic.downcast_ref::<&str>().copied())
201                    .unwrap_or("unknown panic");
202                SomaError::Execution {
203                    node_id: node_id.clone(),
204                    message: format!("forward panicked: {msg}"),
205                }
206            })??;
207
208            event_bus.emit(Event::NodeCompleted {
209                run_id: run_id.clone(),
210                node_id: node_id.to_string(),
211                duration: start.elapsed(),
212                output_summary: format!("{output}"),
213            });
214
215            outputs.insert(node_id.clone(), output);
216        }
217
218        let last_output = outputs.values().last().cloned().unwrap_or(Value::Empty);
219
220        // Forward outputs keyed by node_id (for GraphSession inspection).
221        // Trained states added with __state_ prefix (for Worker to extract).
222        for (id, state) in &trained_states {
223            outputs.insert(format!("__state_{id}"), state.clone());
224        }
225        Ok((last_output, outputs))
226    }
227
228    fn forward(
229        &self,
230        plan: &ExecutionPlan,
231        filters: &FilterLibrary,
232        cache: &dyn CacheStore,
233        event_bus: &Arc<EventBus>,
234        input: &Value,
235    ) -> Result<Value> {
236        let node_ids = plan.node_ids();
237        let graph_info = GraphInfo::for_linear(&node_ids);
238
239        let mut ctx =
240            Context::new(event_bus.clone(), timestamp_id("forward")).with_graph_info(graph_info);
241
242        // Set input for root nodes
243        if let Some(first) = node_ids.first() {
244            ctx.set(format!("__input_{first}"), input.clone());
245        }
246        ctx.set("__input__", input.clone());
247
248        plan.execute(&mut ctx, filters, cache)?;
249
250        // Return last executed node's output
251        ctx.execution_order
252            .last()
253            .and_then(|id| ctx.store.remove(id))
254            .and_then(|vv| vv.as_value().cloned())
255            .ok_or_else(|| SomaError::Other("no output produced".into()))
256    }
257}