Skip to main content

somatize_worker/
worker.rs

1//! Worker — receives and executes plans from a coordinator.
2
3use crate::protocol::*;
4use somatize_core::cache::CacheStore;
5use somatize_core::event::Event;
6use somatize_core::filter::Filter;
7use somatize_runtime::{Context, EventBus, FilterLibrary, MemoryCache, execute};
8use std::sync::Arc;
9use std::time::Instant;
10
11/// Worker state: manages execution of plans received from a coordinator.
12pub struct Worker {
13    pub id: WorkerId,
14    pub capabilities: Capabilities,
15    event_bus: Arc<EventBus>,
16    cache: Arc<dyn CacheStore>,
17    filters: FilterLibrary,
18}
19
20impl Worker {
21    pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
22        Self {
23            id: id.into(),
24            capabilities,
25            event_bus: Arc::new(EventBus::new(256)),
26            cache: Arc::new(MemoryCache::default()),
27            filters: FilterLibrary::new(),
28        }
29    }
30
31    /// Set a custom cache store (e.g. tiered or shared).
32    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
33        self.cache = cache;
34        self
35    }
36
37    /// Register a filter that this worker can execute.
38    pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
39        self.filters.register(node_id, filter);
40    }
41
42    /// Subscribe to execution events.
43    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
44        self.event_bus.subscribe()
45    }
46
47    /// Build a registration message.
48    pub fn registration_message(&self) -> WorkerToCoordinator {
49        WorkerToCoordinator::Register {
50            worker_id: self.id.clone(),
51            capabilities: self.capabilities.clone(),
52        }
53    }
54
55    /// Execute a serialized plan.
56    pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
57        let start = Instant::now();
58
59        let mut ctx = Context::new(
60            self.event_bus.clone(),
61            format!("worker_run_{}", plan.plan_id),
62        );
63
64        // Resolve input data (inline or from DataStore)
65        if let Some(input_source) = &plan.input {
66            use crate::protocol::InputSource;
67            let input_value = match input_source {
68                InputSource::Inline { value } => value.clone(),
69                InputSource::Reference { data_ref } => {
70                    // Try to load from context's data store
71                    if let Some(store) = &ctx.data_store {
72                        store
73                            .get(data_ref)
74                            .unwrap_or(somatize_core::value::Value::Empty)
75                    } else {
76                        tracing::warn!("DataRef input but no DataStore configured on worker");
77                        somatize_core::value::Value::Empty
78                    }
79                }
80            };
81            ctx.set("input", input_value);
82        }
83
84        match execute(&plan.plan, &mut ctx, &self.filters, self.cache.as_ref()) {
85            Ok(()) => {
86                // Find the last output
87                let output = ctx
88                    .execution_order
89                    .last()
90                    .and_then(|id| ctx.get(id))
91                    .cloned()
92                    .unwrap_or(somatize_core::value::Value::Empty);
93
94                PlanResult::Success {
95                    output,
96                    duration_ms: start.elapsed().as_millis() as u64,
97                }
98            }
99            Err(e) => PlanResult::Failed {
100                error: e.to_string(),
101                duration_ms: start.elapsed().as_millis() as u64,
102            },
103        }
104    }
105
106    /// Check if this worker matches a remote target.
107    pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
108        match target {
109            somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
110            somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::InputSource;
    use somatize_compiler::ExecutionPlan;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{FilterKind, FilterMeta, RemoteTarget, StreamMode};
    use somatize_core::value::Value;

    /// Stateless test filter: doubles every element of a tensor input and
    /// passes any other value through unchanged.
    struct TestDoubler;

    impl Filter for TestDoubler {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"TestDoubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            if let Value::Tensor { values, shape } = x {
                let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                return Ok(Value::tensor(doubled, shape.clone()));
            }
            Ok(x.clone())
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "TestDoubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Worker fixture shared by all tests: id `test_worker`, tags `cpu` and `test`.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["cpu".into(), "test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// Plan step that executes the single node `node_id`.
    fn run_node(node_id: &str) -> ExecutionPlan {
        ExecutionPlan::Execute {
            node_id: node_id.into(),
        }
    }

    /// Serialized plan with empty metadata and an optional inline input.
    fn plan_with(plan_id: &str, plan: ExecutionPlan, input: Option<Value>) -> SerializedPlan {
        SerializedPlan {
            plan_id: plan_id.into(),
            plan,
            input: input.map(|value| InputSource::Inline { value }),
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn worker_registration() {
        match make_worker().registration_message() {
            WorkerToCoordinator::Register {
                worker_id,
                capabilities,
            } => {
                assert_eq!(worker_id, "test_worker");
                assert_eq!(capabilities.cpu_cores, 4);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn worker_executes_plan_successfully() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));

        let plan = plan_with(
            "p_001",
            run_node("doubler"),
            Some(Value::tensor(vec![1.0, 2.0, 3.0], vec![3])),
        );

        match worker.execute_plan(&plan) {
            PlanResult::Success {
                output,
                duration_ms,
            } => {
                let (data, _) = output.as_tensor().unwrap();
                assert_eq!(data, &[2.0, 4.0, 6.0]);
                assert!(duration_ms < 1000);
            }
            result => panic!("expected success, got: {result:?}"),
        }
    }

    #[test]
    fn worker_handles_missing_filter() {
        // No filters are registered, so the node lookup must fail.
        let mut worker = make_worker();
        let plan = plan_with("p_002", run_node("nonexistent"), None);

        let result = worker.execute_plan(&plan);
        assert!(matches!(result, PlanResult::Failed { .. }));
    }

    #[test]
    fn worker_matches_target_by_id() {
        let worker = make_worker();
        let matches_id = |id: &str| worker.matches_target(&RemoteTarget::WorkerId(id.into()));

        assert!(matches_id("test_worker"));
        assert!(!matches_id("other"));
    }

    #[test]
    fn worker_matches_target_by_tag() {
        let worker = make_worker();
        let matches_tag = |tag: &str| worker.matches_target(&RemoteTarget::Tag(tag.into()));

        assert!(matches_tag("cpu"));
        assert!(matches_tag("test"));
        assert!(!matches_tag("gpu"));
    }

    #[test]
    fn worker_executes_sequence() {
        let mut worker = make_worker();
        for node_id in ["d1", "d2"] {
            worker.register_filter(node_id, Box::new(TestDoubler));
        }

        let steps = ExecutionPlan::Sequence(vec![run_node("d1"), run_node("d2")]);
        let plan = plan_with("p_003", steps, Some(Value::tensor(vec![5.0], vec![1])));

        match worker.execute_plan(&plan) {
            PlanResult::Success { output, .. } => {
                let (data, _) = output.as_tensor().unwrap();
                assert_eq!(data, &[20.0]); // 5 * 2 * 2
            }
            _ => panic!("expected success"),
        }
    }

    #[test]
    fn worker_emits_events() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        // Subscribe before running so the receiver sees this run's events.
        let mut rx = worker.subscribe();

        let plan = plan_with(
            "p_004",
            run_node("doubler"),
            Some(Value::tensor(vec![1.0], vec![1])),
        );
        let _ = worker.execute_plan(&plan);

        // Drain everything currently buffered, stopping at the first error.
        let events: Vec<Event> = std::iter::from_fn(|| rx.try_recv().ok()).collect();

        let saw_start = events
            .iter()
            .any(|e| matches!(e, Event::NodeStarted { .. }));
        let saw_completion = events
            .iter()
            .any(|e| matches!(e, Event::NodeCompleted { .. }));
        assert!(saw_start);
        assert!(saw_completion);
    }
}