Skip to main content

somatize_worker/
worker.rs

1//! Worker — receives and executes plans from a coordinator.
2
3use crate::protocol::*;
4use somatize_core::cache::{CacheKey, CacheStore};
5use somatize_core::error::Result as SomaResult;
6use somatize_core::event::Event;
7use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
8use somatize_core::value::Value;
9use somatize_runtime::{Context, EventBus, FilterLibrary, MemoryCache, execute};
10use std::sync::Arc;
11use std::time::Instant;
12
/// Runner for a Python filter that arrived as `cloudpickle` bytes.
///
/// Nothing Python-side is kept alive on the Rust side: every method call spawns
/// a `python3` child process, which unpickles the object and invokes the
/// requested method there.
///
/// NOTE(review): unpickling runs arbitrary code supplied by the sender, so plans
/// carrying pickled filters must only be accepted from a trusted coordinator.
struct PickledFilterRunner {
    /// Node ID this runner serves; used as the filter name and in error reports.
    node_id: String,
    /// Raw `cloudpickle.dumps()` output for the original Python filter object.
    pickled_bytes: Vec<u8>,
}
21
22impl Filter for PickledFilterRunner {
23    fn config_hash(&self) -> CacheKey {
24        CacheKey::from_parts(&[&self.pickled_bytes])
25    }
26
27    fn fit(&self, x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
28        self.run_python("fit", x)
29    }
30
31    fn forward(&self, x: &Value, state: &Value) -> SomaResult<Value> {
32        let input = if matches!(state, Value::Empty) {
33            x.clone()
34        } else {
35            Value::json(serde_json::json!({
36                "x": serde_json::to_value(x).unwrap_or_default(),
37                "state": serde_json::to_value(state).unwrap_or_default(),
38            }))
39        };
40        self.run_python("forward", &input)
41    }
42
43    fn meta(&self) -> FilterMeta {
44        FilterMeta {
45            name: self.node_id.clone(),
46            kind: FilterKind::Stateless,
47            cacheable: true,
48            differentiable: false,
49            stream_mode: StreamMode::FixedState,
50            distribution: somatize_core::filter::Distribution::Local,
51            input_schema: None,
52            output_schema: None,
53        }
54    }
55}
56
57impl PickledFilterRunner {
58    fn run_python(&self, method: &str, input: &Value) -> SomaResult<Value> {
59        use base64::engine::{Engine, general_purpose::STANDARD};
60
61        let input_json = serde_json::to_string(input)
62            .map_err(|e| somatize_core::error::SomaError::Other(format!("serialize input: {e}")))?;
63        let pickled_b64 = STANDARD.encode(&self.pickled_bytes);
64
65        // Python script: deserialize filter with cloudpickle, call method, return JSON
66        let script = format!(
67            r#"
68import json, sys, base64, cloudpickle
69
70pickled = base64.b64decode(sys.argv[1])
71obj = cloudpickle.loads(pickled)
72input_data = json.loads(sys.argv[2])
73
74if isinstance(input_data, dict) and "x" in input_data and "state" in input_data:
75    result = obj.{method}(input_data["x"], input_data["state"])
76else:
77    result = obj.{method}(input_data, {{}})
78
79print(json.dumps(result))
80"#,
81        );
82
83        let output = std::process::Command::new("python3")
84            .args(["-c", &script, &pickled_b64, &input_json])
85            .output()
86            .map_err(|e| {
87                somatize_core::error::SomaError::Other(format!("python exec failed: {e}"))
88            })?;
89
90        if !output.status.success() {
91            let stderr = String::from_utf8_lossy(&output.stderr);
92            return Err(somatize_core::error::SomaError::Execution {
93                node_id: self.node_id.clone(),
94                message: format!("Python error: {stderr}"),
95            });
96        }
97
98        let stdout = String::from_utf8_lossy(&output.stdout);
99        let result: serde_json::Value = serde_json::from_str(stdout.trim()).map_err(|e| {
100            somatize_core::error::SomaError::Other(format!(
101                "parse python output: {e}\nstdout: {stdout}"
102            ))
103        })?;
104
105        if let Some(arr) = result.as_array() {
106            let values: Vec<f64> = arr.iter().filter_map(|v| v.as_f64()).collect();
107            if !values.is_empty() {
108                return Ok(Value::tensor(values.clone(), vec![values.len()]));
109            }
110        }
111
112        Ok(Value::json(result))
113    }
114}
115
116/// Worker state: manages execution of plans received from a coordinator.
117pub struct Worker {
118    pub id: WorkerId,
119    pub capabilities: Capabilities,
120    event_bus: Arc<EventBus>,
121    cache: Arc<dyn CacheStore>,
122    filters: FilterLibrary,
123}
124
125impl Worker {
126    pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
127        Self {
128            id: id.into(),
129            capabilities,
130            event_bus: Arc::new(EventBus::new(256)),
131            cache: Arc::new(MemoryCache::default()),
132            filters: FilterLibrary::new(),
133        }
134    }
135
136    /// Set a custom cache store (e.g. tiered or shared).
137    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
138        self.cache = cache;
139        self
140    }
141
142    /// Register a filter that this worker can execute.
143    pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
144        self.filters.register(node_id, filter);
145    }
146
147    /// Subscribe to execution events.
148    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
149        self.event_bus.subscribe()
150    }
151
152    /// Build a registration message.
153    pub fn registration_message(&self) -> WorkerToCoordinator {
154        WorkerToCoordinator::Register {
155            worker_id: self.id.clone(),
156            capabilities: self.capabilities.clone(),
157        }
158    }
159
160    /// Execute a serialized plan.
161    ///
162    /// If the plan contains serialized filter definitions, they are registered
163    /// temporarily for this execution (alongside any pre-registered filters).
164    pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
165        let start = Instant::now();
166
167        // Register pickled filters (from remote client via cloudpickle)
168        for sf in &plan.filters {
169            let filter = Box::new(PickledFilterRunner {
170                pickled_bytes: sf.pickled_filter.clone(),
171                node_id: sf.node_id.clone(),
172            });
173            self.filters.register(&sf.node_id, filter);
174            if let Some(state) = &sf.state {
175                self.filters.set_state(&sf.node_id, state.clone());
176            }
177        }
178
179        let mut ctx = Context::new(
180            self.event_bus.clone(),
181            format!("worker_run_{}", plan.plan_id),
182        );
183
184        // Resolve input data (inline or from DataStore)
185        if let Some(input_source) = &plan.input {
186            use crate::protocol::InputSource;
187            let input_value = match input_source {
188                InputSource::Inline { value } => value.clone(),
189                InputSource::Reference { data_ref } => {
190                    if let Some(store) = &ctx.data_store {
191                        store
192                            .get(data_ref)
193                            .unwrap_or(somatize_core::value::Value::Empty)
194                    } else {
195                        tracing::warn!("DataRef input but no DataStore configured on worker");
196                        somatize_core::value::Value::Empty
197                    }
198                }
199            };
200            ctx.set("input", input_value);
201        }
202
203        match execute(&plan.plan, &mut ctx, &self.filters, self.cache.as_ref()) {
204            Ok(()) => {
205                // Find the last output
206                let output = ctx
207                    .execution_order
208                    .last()
209                    .and_then(|id| ctx.get(id))
210                    .cloned()
211                    .unwrap_or(somatize_core::value::Value::Empty);
212
213                PlanResult::Success {
214                    output,
215                    duration_ms: start.elapsed().as_millis() as u64,
216                }
217            }
218            Err(e) => PlanResult::Failed {
219                error: e.to_string(),
220                duration_ms: start.elapsed().as_millis() as u64,
221            },
222        }
223    }
224
225    /// Check if this worker matches a remote target.
226    pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
227        match target {
228            somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
229            somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
230        }
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_compiler::ExecutionPlan;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{FilterKind, FilterMeta, RemoteTarget, StreamMode};
    use somatize_core::value::Value;

    /// Filter double: multiplies every tensor element by two and echoes any
    /// non-tensor value unchanged.
    struct TestDoubler;

    impl Filter for TestDoubler {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"TestDoubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            let Value::Tensor { values, shape } = x else {
                return Ok(x.clone());
            };
            let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
            Ok(Value::tensor(doubled, shape.clone()))
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "TestDoubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Worker `test_worker`: 4 cores, 8 GB RAM, no GPUs, tagged `cpu` and `test`.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["cpu".into(), "test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// Plan step that runs the filter registered as `node_id`.
    fn execute_node(node_id: &str) -> ExecutionPlan {
        ExecutionPlan::Execute {
            node_id: node_id.into(),
        }
    }

    /// Wraps `plan` with an optional inline input, no shipped filters and empty metadata.
    fn serialized_plan(plan_id: &str, plan: ExecutionPlan, input: Option<Value>) -> SerializedPlan {
        SerializedPlan {
            plan_id: plan_id.into(),
            plan,
            input: input.map(|value| crate::protocol::InputSource::Inline { value }),
            filters: vec![],
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn worker_registration() {
        let msg = make_worker().registration_message();
        let WorkerToCoordinator::Register {
            worker_id,
            capabilities,
        } = msg
        else {
            panic!("wrong message type");
        };
        assert_eq!(worker_id, "test_worker");
        assert_eq!(capabilities.cpu_cores, 4);
    }

    #[test]
    fn worker_executes_plan_successfully() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));

        let plan = serialized_plan(
            "p_001",
            execute_node("doubler"),
            Some(Value::tensor(vec![1.0, 2.0, 3.0], vec![3])),
        );

        match worker.execute_plan(&plan) {
            PlanResult::Success {
                output,
                duration_ms,
            } => {
                let (data, _) = output.as_tensor().unwrap();
                assert_eq!(data, &[2.0, 4.0, 6.0]);
                assert!(duration_ms < 1000);
            }
            result => panic!("expected success, got: {result:?}"),
        }
    }

    #[test]
    fn worker_handles_missing_filter() {
        // Nothing is registered, so executing the plan must fail.
        let mut worker = make_worker();
        let plan = serialized_plan("p_002", execute_node("nonexistent"), None);

        let result = worker.execute_plan(&plan);
        assert!(matches!(result, PlanResult::Failed { .. }));
    }

    #[test]
    fn worker_matches_target_by_id() {
        let worker = make_worker();
        assert!(worker.matches_target(&RemoteTarget::WorkerId("test_worker".into())));
        assert!(!worker.matches_target(&RemoteTarget::WorkerId("other".into())));
    }

    #[test]
    fn worker_matches_target_by_tag() {
        let worker = make_worker();
        for tag in ["cpu", "test"] {
            assert!(worker.matches_target(&RemoteTarget::Tag(tag.into())));
        }
        assert!(!worker.matches_target(&RemoteTarget::Tag("gpu".into())));
    }

    #[test]
    fn worker_executes_sequence() {
        let mut worker = make_worker();
        for node_id in ["d1", "d2"] {
            worker.register_filter(node_id, Box::new(TestDoubler));
        }

        let plan = serialized_plan(
            "p_003",
            ExecutionPlan::Sequence(vec![execute_node("d1"), execute_node("d2")]),
            Some(Value::tensor(vec![5.0], vec![1])),
        );

        let PlanResult::Success { output, .. } = worker.execute_plan(&plan) else {
            panic!("expected success");
        };
        let (data, _) = output.as_tensor().unwrap();
        assert_eq!(data, &[20.0]); // 5 * 2 * 2
    }

    #[test]
    fn worker_emits_events() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let mut rx = worker.subscribe();

        let plan = serialized_plan(
            "p_004",
            execute_node("doubler"),
            Some(Value::tensor(vec![1.0], vec![1])),
        );
        let _ = worker.execute_plan(&plan);

        // Drain everything buffered so far; stops at the first receive error.
        let events: Vec<Event> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        assert!(events.iter().any(|e| matches!(e, Event::NodeStarted { .. })));
        assert!(events.iter().any(|e| matches!(e, Event::NodeCompleted { .. })));
    }
}