Skip to main content

somatize_runtime/runner/
local.rs

1//! LocalRunner — executes plans locally using the Executor.
2//!
3//! This is the default runner. The worker's RemoteRunner delegates here
4//! after preparing the environment (deserializing filters, resolving input).
5
6use super::Runner;
7use crate::EventBus;
8use crate::executor::{Context, Executable, GraphInfo};
9use crate::filter_library::FilterLibrary;
10
11use somatize_compiler::ExecutionPlan;
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::{Filter, FilterKind};
16use somatize_core::util::timestamp_id;
17use somatize_core::value::Value;
18use std::borrow::Cow;
19use std::collections::HashMap;
20use std::sync::Arc;
21
22/// Resolve every node_id in ``node_ids`` to its filter instance.
23/// Returns ``None`` if any id is missing from the library — in that case
24/// the runner falls back to sequential execution with a NodeNotFound error.
25fn collect_peers(
26    filters: &FilterLibrary,
27    node_ids: &[String],
28) -> Option<Vec<(String, Arc<dyn Filter>)>> {
29    node_ids
30        .iter()
31        .map(|id| filters.get(id).map(|f| (id.clone(), f)))
32        .collect()
33}
34
35/// Executes plans locally — same logic for local and remote execution.
36pub struct LocalRunner;
37
38impl LocalRunner {
39    /// Fit a Sequence plan, handling Composite steps as blocks.
40    fn fit_sequence(
41        &self,
42        steps: &[ExecutionPlan],
43        filters: &FilterLibrary,
44        cache: &dyn CacheStore,
45        event_bus: &Arc<EventBus>,
46        input: &Value,
47        y: Option<&Value>,
48    ) -> Result<(Value, HashMap<String, Value>)> {
49        let mut current_input = input.clone();
50        let mut all_outputs = HashMap::new();
51
52        for step in steps {
53            let handled = if let ExecutionPlan::Composite { node_ids } = step
54                && let Some(peers) = collect_peers(filters, node_ids)
55                && let Some(filter) = filters.get(&node_ids[0])
56                && let Some(result) = filter.composite_fit(&peers, &current_input, y)
57            {
58                let (output, states) = result?;
59                for (id, state) in &states {
60                    all_outputs.insert(format!("__state_{id}"), state.clone());
61                }
62                if let Some(last_id) = node_ids.last() {
63                    all_outputs.insert(last_id.clone(), output.clone());
64                }
65                current_input = output;
66                true
67            } else {
68                false
69            };
70
71            if !handled {
72                let sub_result = <Self as Runner>::fit(
73                    self,
74                    step,
75                    filters,
76                    cache,
77                    event_bus,
78                    &current_input,
79                    y,
80                )?;
81                current_input = sub_result.0;
82                all_outputs.extend(sub_result.1);
83            }
84        }
85
86        Ok((current_input, all_outputs))
87    }
88}
89
90impl Runner for LocalRunner {
91    fn fit(
92        &self,
93        plan: &ExecutionPlan,
94        filters: &FilterLibrary,
95        cache: &dyn CacheStore,
96        event_bus: &Arc<EventBus>,
97        input: &Value,
98        y: Option<&Value>,
99    ) -> Result<(Value, HashMap<String, Value>)> {
100        // Handle Composite plan: delegate to composite_fit on the first filter
101        if let ExecutionPlan::Composite { node_ids } = plan
102            && let Some(peers) = collect_peers(filters, node_ids)
103            && let Some(filter) = filters.get(&node_ids[0])
104            && let Some(result) = filter.composite_fit(&peers, input, y)
105        {
106            return result;
107            // Fallback: treat as sequential if composite_fit not supported
108        }
109
110        // Handle Sequence that may contain Composite steps
111        if let ExecutionPlan::Sequence(steps) = plan {
112            return self.fit_sequence(steps, filters, cache, event_bus, input, y);
113        }
114
115        // Single node or other plan types: sequential fit
116        let node_id_refs = plan.node_ids();
117        let node_ids: Vec<String> = node_id_refs.iter().map(|s| s.to_string()).collect();
118        let graph_info = GraphInfo::for_linear(&node_id_refs);
119        let run_id = timestamp_id("fit");
120        let mut outputs: HashMap<String, Value> = HashMap::new();
121        let mut trained_states: HashMap<String, Value> = HashMap::new();
122
123        if let Some(first) = node_ids.first() {
124            outputs.insert(format!("__input_{first}"), input.clone());
125        }
126
127        for node_id in &node_ids {
128            let filter = filters
129                .get(node_id)
130                .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
131
132            let meta = filter.meta();
133
134            event_bus.emit(Event::NodeStarted {
135                run_id: run_id.clone(),
136                node_id: node_id.to_string(),
137                kind: meta.kind,
138            });
139
140            // Resolve input from predecessors
141            let preds = graph_info.predecessors(node_id);
142            let node_input = match preds.len() {
143                0 => input.clone(),
144                1 => outputs
145                    .get(&preds[0])
146                    .cloned()
147                    .unwrap_or_else(|| input.clone()),
148                _ => {
149                    let mut merged = serde_json::Map::new();
150                    for pred_id in preds {
151                        if let Some(val) = outputs.get(pred_id.as_str()) {
152                            let json_val =
153                                serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
154                            merged.insert(pred_id.clone(), json_val);
155                        }
156                    }
157                    Value::json(serde_json::Value::Object(merged))
158                }
159            };
160
161            let start = std::time::Instant::now();
162
163            // Fit trainable filters. Use Cow so the non-trainable branch
164            // can borrow the existing state (via the StateStore Arc)
165            // without cloning potentially huge tensors on every forward
166            // call.
167            let library_state: Option<Arc<Value>> = if meta.kind == FilterKind::Trainable {
168                None
169            } else {
170                filters.get_state(node_id)
171            };
172            let empty_state = Value::Empty;
173            let state: Cow<Value> = if meta.kind == FilterKind::Trainable {
174                let data_hash =
175                    CacheKey::hash_data(&serde_json::to_vec(&node_input).unwrap_or_default());
176                let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
177
178                let s = if let Some(cached) = cache.get(&state_key)? {
179                    cached
180                } else {
181                    let fitted = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
182                        filter.fit(&node_input, y)
183                    }))
184                    .map_err(|panic| {
185                        let msg = panic
186                            .downcast_ref::<String>()
187                            .map(|s| s.as_str())
188                            .or_else(|| panic.downcast_ref::<&str>().copied())
189                            .unwrap_or("unknown panic");
190                        SomaError::Execution {
191                            node_id: node_id.clone(),
192                            message: format!("fit panicked: {msg}"),
193                        }
194                    })??;
195                    let _ = cache.put(&state_key, &fitted);
196                    fitted
197                };
198                trained_states.insert(node_id.clone(), s.clone());
199                Cow::Owned(s)
200            } else {
201                match library_state.as_deref() {
202                    Some(v) => Cow::Borrowed(v),
203                    None => Cow::Borrowed(&empty_state),
204                }
205            };
206
207            // Forward with state (catch panics)
208            let output = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
209                filter.forward(&node_input, &state)
210            }))
211            .map_err(|panic| {
212                let msg = panic
213                    .downcast_ref::<String>()
214                    .map(|s| s.as_str())
215                    .or_else(|| panic.downcast_ref::<&str>().copied())
216                    .unwrap_or("unknown panic");
217                SomaError::Execution {
218                    node_id: node_id.clone(),
219                    message: format!("forward panicked: {msg}"),
220                }
221            })??;
222
223            event_bus.emit(Event::NodeCompleted {
224                run_id: run_id.clone(),
225                node_id: node_id.to_string(),
226                duration: start.elapsed(),
227                output_summary: format!("{output}"),
228            });
229
230            outputs.insert(node_id.clone(), output);
231        }
232
233        let last_output = outputs.values().last().cloned().unwrap_or(Value::Empty);
234
235        // Forward outputs keyed by node_id (for GraphSession inspection).
236        // Trained states added with __state_ prefix (for Worker to extract).
237        for (id, state) in &trained_states {
238            outputs.insert(format!("__state_{id}"), state.clone());
239        }
240        Ok((last_output, outputs))
241    }
242
243    fn forward(
244        &self,
245        plan: &ExecutionPlan,
246        filters: &FilterLibrary,
247        cache: &dyn CacheStore,
248        event_bus: &Arc<EventBus>,
249        input: &Value,
250    ) -> Result<Value> {
251        let node_ids = plan.node_ids();
252        let graph_info = GraphInfo::for_linear(&node_ids);
253
254        let mut ctx =
255            Context::new(event_bus.clone(), timestamp_id("forward")).with_graph_info(graph_info);
256
257        // Set input for root nodes
258        if let Some(first) = node_ids.first() {
259            ctx.set(format!("__input_{first}"), input.clone());
260        }
261        ctx.set("__input__", input.clone());
262
263        plan.execute(&mut ctx, filters, cache)?;
264
265        // Return last executed node's output
266        ctx.execution_order
267            .last()
268            .and_then(|id| ctx.store.remove(id))
269            .and_then(|vv| vv.as_value().cloned())
270            .ok_or_else(|| SomaError::Other("no output produced".into()))
271    }
272}