Skip to main content

somatize_runtime/runner/
local.rs

1//! LocalRunner — executes plans locally using the Executor.
2//!
3//! This is the default runner. The worker's RemoteRunner delegates here
4//! after preparing the environment (deserializing filters, resolving input).
5
6use super::Runner;
7use crate::EventBus;
8use crate::executor::{Context, Executable, GraphInfo};
9use crate::filter_library::FilterLibrary;
10
11use somatize_compiler::ExecutionPlan;
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::FilterKind;
16use somatize_core::util::timestamp_id;
17use somatize_core::value::Value;
18use std::collections::HashMap;
19use std::sync::Arc;
20
/// Executes plans locally — same logic for local and remote execution.
///
/// Stateless: every call to `fit`/`forward` receives the plan, filters, cache
/// and event bus it needs, so the runner itself is a zero-sized unit struct
/// that is free to copy, default-construct and debug-print.
#[derive(Debug, Clone, Copy, Default)]
pub struct LocalRunner;
23
24impl Runner for LocalRunner {
25    fn fit(
26        &self,
27        plan: &ExecutionPlan,
28        filters: &FilterLibrary,
29        cache: &dyn CacheStore,
30        event_bus: &Arc<EventBus>,
31        input: &Value,
32        y: Option<&Value>,
33    ) -> Result<(Value, HashMap<String, Value>)> {
34        let node_id_refs = plan.node_ids();
35        let node_ids: Vec<String> = node_id_refs.iter().map(|s| s.to_string()).collect();
36        let graph_info = GraphInfo::for_linear(&node_id_refs);
37        let run_id = timestamp_id("fit");
38        let mut outputs: HashMap<String, Value> = HashMap::new();
39        let mut trained_states: HashMap<String, Value> = HashMap::new();
40
41        // Set initial input for first node
42        if let Some(first) = node_ids.first() {
43            outputs.insert(format!("__input_{first}"), input.clone());
44        }
45
46        for node_id in &node_ids {
47            let filter = filters
48                .get(node_id)
49                .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
50
51            let meta = filter.meta();
52
53            event_bus.emit(Event::NodeStarted {
54                run_id: run_id.clone(),
55                node_id: node_id.to_string(),
56                kind: meta.kind,
57            });
58
59            // Resolve input from predecessors
60            let preds = graph_info.predecessors(node_id);
61            let node_input = match preds.len() {
62                0 => input.clone(),
63                1 => outputs
64                    .get(&preds[0])
65                    .cloned()
66                    .unwrap_or_else(|| input.clone()),
67                _ => {
68                    let mut merged = serde_json::Map::new();
69                    for pred_id in preds {
70                        if let Some(val) = outputs.get(pred_id.as_str()) {
71                            let json_val =
72                                serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
73                            merged.insert(pred_id.clone(), json_val);
74                        }
75                    }
76                    Value::Json(serde_json::Value::Object(merged))
77                }
78            };
79
80            let start = std::time::Instant::now();
81
82            // Fit trainable filters
83            let state = if meta.kind == FilterKind::Trainable {
84                let data_hash =
85                    CacheKey::hash_data(&serde_json::to_vec(&node_input).unwrap_or_default());
86                let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
87
88                let s = if let Some(cached) = cache.get(&state_key)? {
89                    cached
90                } else {
91                    let fitted = filter.fit(&node_input, y)?;
92                    let _ = cache.put(&state_key, &fitted);
93                    fitted
94                };
95                trained_states.insert(node_id.clone(), s.clone());
96                s
97            } else {
98                filters.get_state(node_id).cloned().unwrap_or(Value::Empty)
99            };
100
101            // Forward with state
102            let output = filter.forward(&node_input, &state)?;
103
104            event_bus.emit(Event::NodeCompleted {
105                run_id: run_id.clone(),
106                node_id: node_id.to_string(),
107                duration: start.elapsed(),
108                output_summary: format!("{output}"),
109            });
110
111            outputs.insert(node_id.clone(), output);
112        }
113
114        let last_output = outputs.values().last().cloned().unwrap_or(Value::Empty);
115
116        Ok((last_output, outputs))
117    }
118
119    fn forward(
120        &self,
121        plan: &ExecutionPlan,
122        filters: &FilterLibrary,
123        cache: &dyn CacheStore,
124        event_bus: &Arc<EventBus>,
125        input: &Value,
126    ) -> Result<Value> {
127        let node_ids = plan.node_ids();
128        let graph_info = GraphInfo::for_linear(&node_ids);
129
130        let mut ctx =
131            Context::new(event_bus.clone(), timestamp_id("forward")).with_graph_info(graph_info);
132
133        // Set input for root nodes
134        if let Some(first) = node_ids.first() {
135            ctx.set(format!("__input_{first}"), input.clone());
136        }
137        ctx.set("__input__", input.clone());
138
139        plan.execute(&mut ctx, filters, cache)?;
140
141        // Return last executed node's output
142        ctx.execution_order
143            .last()
144            .and_then(|id| ctx.store.remove(id))
145            .and_then(|vv| vv.as_value().cloned())
146            .ok_or_else(|| SomaError::Other("no output produced".into()))
147    }
148}