Skip to main content

somatize_runtime/runner/
local.rs

1//! LocalRunner — executes plans locally using the Executor.
2//!
3//! This is the default runner. The worker's RemoteRunner delegates here
4//! after preparing the environment (deserializing filters, resolving input).
5
6use super::Runner;
7use crate::EventBus;
8use crate::executor::{Context, Executable, GraphInfo};
9use crate::filter_library::FilterLibrary;
10
11use somatize_compiler::ExecutionPlan;
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::FilterKind;
16use somatize_core::util::timestamp_id;
17use somatize_core::value::Value;
18use std::collections::HashMap;
19use std::sync::Arc;
20
/// Executes plans in the current process.
///
/// `LocalRunner` carries no state: everything it needs (plan, filters, cache,
/// event bus, input) is passed to the [`Runner`] methods on each call, so the
/// same runner value can be reused for any number of runs.
#[derive(Debug, Clone, Copy, Default)]
pub struct LocalRunner;
23
24impl Runner for LocalRunner {
25    fn fit(
26        &self,
27        plan: &ExecutionPlan,
28        filters: &FilterLibrary,
29        cache: &dyn CacheStore,
30        event_bus: &Arc<EventBus>,
31        input: &Value,
32        y: Option<&Value>,
33    ) -> Result<(Value, HashMap<String, Value>)> {
34        let node_id_refs = plan.node_ids();
35        let node_ids: Vec<String> = node_id_refs.iter().map(|s| s.to_string()).collect();
36        let graph_info = GraphInfo::for_linear(&node_id_refs);
37        let run_id = timestamp_id("fit");
38        let mut outputs: HashMap<String, Value> = HashMap::new();
39        let mut trained_states: HashMap<String, Value> = HashMap::new();
40
41        // Set initial input for first node
42        if let Some(first) = node_ids.first() {
43            outputs.insert(format!("__input_{first}"), input.clone());
44        }
45
46        for node_id in &node_ids {
47            let filter = filters
48                .get(node_id)
49                .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
50
51            let meta = filter.meta();
52
53            event_bus.emit(Event::NodeStarted {
54                run_id: run_id.clone(),
55                node_id: node_id.to_string(),
56                kind: meta.kind,
57            });
58
59            // Resolve input from predecessors
60            let preds = graph_info.predecessors(node_id);
61            let node_input = match preds.len() {
62                0 => input.clone(),
63                1 => outputs
64                    .get(&preds[0])
65                    .cloned()
66                    .unwrap_or_else(|| input.clone()),
67                _ => {
68                    let mut merged = serde_json::Map::new();
69                    for pred_id in preds {
70                        if let Some(val) = outputs.get(pred_id.as_str()) {
71                            let json_val =
72                                serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
73                            merged.insert(pred_id.clone(), json_val);
74                        }
75                    }
76                    Value::Json(serde_json::Value::Object(merged))
77                }
78            };
79
80            let start = std::time::Instant::now();
81
82            // Fit trainable filters
83            let state = if meta.kind == FilterKind::Trainable {
84                let data_hash =
85                    CacheKey::hash_data(&serde_json::to_vec(&node_input).unwrap_or_default());
86                let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
87
88                let s = if let Some(cached) = cache.get(&state_key)? {
89                    cached
90                } else {
91                    let fitted = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
92                        filter.fit(&node_input, y)
93                    }))
94                    .map_err(|panic| {
95                        let msg = panic
96                            .downcast_ref::<String>()
97                            .map(|s| s.as_str())
98                            .or_else(|| panic.downcast_ref::<&str>().copied())
99                            .unwrap_or("unknown panic");
100                        SomaError::Execution {
101                            node_id: node_id.clone(),
102                            message: format!("fit panicked: {msg}"),
103                        }
104                    })??;
105                    let _ = cache.put(&state_key, &fitted);
106                    fitted
107                };
108                trained_states.insert(node_id.clone(), s.clone());
109                s
110            } else {
111                filters.get_state(node_id).cloned().unwrap_or(Value::Empty)
112            };
113
114            // Forward with state (catch panics)
115            let output = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
116                filter.forward(&node_input, &state)
117            }))
118            .map_err(|panic| {
119                let msg = panic
120                    .downcast_ref::<String>()
121                    .map(|s| s.as_str())
122                    .or_else(|| panic.downcast_ref::<&str>().copied())
123                    .unwrap_or("unknown panic");
124                SomaError::Execution {
125                    node_id: node_id.clone(),
126                    message: format!("forward panicked: {msg}"),
127                }
128            })??;
129
130            event_bus.emit(Event::NodeCompleted {
131                run_id: run_id.clone(),
132                node_id: node_id.to_string(),
133                duration: start.elapsed(),
134                output_summary: format!("{output}"),
135            });
136
137            outputs.insert(node_id.clone(), output);
138        }
139
140        let last_output = outputs.values().last().cloned().unwrap_or(Value::Empty);
141
142        // Forward outputs keyed by node_id (for GraphSession inspection).
143        // Trained states added with __state_ prefix (for Worker to extract).
144        for (id, state) in &trained_states {
145            outputs.insert(format!("__state_{id}"), state.clone());
146        }
147        Ok((last_output, outputs))
148    }
149
150    fn forward(
151        &self,
152        plan: &ExecutionPlan,
153        filters: &FilterLibrary,
154        cache: &dyn CacheStore,
155        event_bus: &Arc<EventBus>,
156        input: &Value,
157    ) -> Result<Value> {
158        let node_ids = plan.node_ids();
159        let graph_info = GraphInfo::for_linear(&node_ids);
160
161        let mut ctx =
162            Context::new(event_bus.clone(), timestamp_id("forward")).with_graph_info(graph_info);
163
164        // Set input for root nodes
165        if let Some(first) = node_ids.first() {
166            ctx.set(format!("__input_{first}"), input.clone());
167        }
168        ctx.set("__input__", input.clone());
169
170        plan.execute(&mut ctx, filters, cache)?;
171
172        // Return last executed node's output
173        ctx.execution_order
174            .last()
175            .and_then(|id| ctx.store.remove(id))
176            .and_then(|vv| vv.as_value().cloned())
177            .ok_or_else(|| SomaError::Other("no output produced".into()))
178    }
179}