1use crate::protocol::*;
4use somatize_core::cache::{CacheKey, CacheStore};
5use somatize_core::error::Result as SomaResult;
6use somatize_core::event::Event;
7use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
8use somatize_core::value::Value;
9use somatize_runtime::{Context, EventBus, FilterLibrary, MemoryCache, execute};
10use std::sync::Arc;
11use std::time::Instant;
12
/// A [`Filter`] backed by a pickled Python object.
///
/// Each `fit`/`forward` call spawns a `python3` subprocess that unpickles the
/// object with `cloudpickle` and invokes the method of the same name on it.
struct PickledFilterRunner {
    /// Raw pickle bytes of the Python object; base64-encoded before being
    /// handed to the subprocess and also used as the sole input to
    /// `config_hash`.
    pickled_bytes: Vec<u8>,
    /// Node identifier; reported as the filter name in `meta()` and attached
    /// to execution errors raised when the Python process fails.
    node_id: String,
}
21
22impl Filter for PickledFilterRunner {
23 fn config_hash(&self) -> CacheKey {
24 CacheKey::from_parts(&[&self.pickled_bytes])
25 }
26
27 fn fit(&self, x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
28 self.run_python("fit", x)
29 }
30
31 fn forward(&self, x: &Value, state: &Value) -> SomaResult<Value> {
32 let input = if matches!(state, Value::Empty) {
33 x.clone()
34 } else {
35 Value::json(serde_json::json!({
36 "x": serde_json::to_value(x).unwrap_or_default(),
37 "state": serde_json::to_value(state).unwrap_or_default(),
38 }))
39 };
40 self.run_python("forward", &input)
41 }
42
43 fn meta(&self) -> FilterMeta {
44 FilterMeta {
45 name: self.node_id.clone(),
46 kind: FilterKind::Stateless,
47 cacheable: true,
48 differentiable: false,
49 stream_mode: StreamMode::FixedState,
50 distribution: somatize_core::filter::Distribution::Local,
51 input_schema: None,
52 output_schema: None,
53 }
54 }
55}
56
57impl PickledFilterRunner {
58 fn run_python(&self, method: &str, input: &Value) -> SomaResult<Value> {
59 use base64::engine::{Engine, general_purpose::STANDARD};
60
61 let input_json = serde_json::to_string(input)
62 .map_err(|e| somatize_core::error::SomaError::Other(format!("serialize input: {e}")))?;
63 let pickled_b64 = STANDARD.encode(&self.pickled_bytes);
64
65 let script = format!(
67 r#"
68import json, sys, base64, cloudpickle
69
70pickled = base64.b64decode(sys.argv[1])
71obj = cloudpickle.loads(pickled)
72input_data = json.loads(sys.argv[2])
73
74if isinstance(input_data, dict) and "x" in input_data and "state" in input_data:
75 result = obj.{method}(input_data["x"], input_data["state"])
76else:
77 result = obj.{method}(input_data, {{}})
78
79print(json.dumps(result))
80"#,
81 );
82
83 let output = std::process::Command::new("python3")
84 .args(["-c", &script, &pickled_b64, &input_json])
85 .output()
86 .map_err(|e| {
87 somatize_core::error::SomaError::Other(format!("python exec failed: {e}"))
88 })?;
89
90 if !output.status.success() {
91 let stderr = String::from_utf8_lossy(&output.stderr);
92 return Err(somatize_core::error::SomaError::Execution {
93 node_id: self.node_id.clone(),
94 message: format!("Python error: {stderr}"),
95 });
96 }
97
98 let stdout = String::from_utf8_lossy(&output.stdout);
99 let result: serde_json::Value = serde_json::from_str(stdout.trim()).map_err(|e| {
100 somatize_core::error::SomaError::Other(format!(
101 "parse python output: {e}\nstdout: {stdout}"
102 ))
103 })?;
104
105 if let Some(arr) = result.as_array() {
106 let values: Vec<f64> = arr.iter().filter_map(|v| v.as_f64()).collect();
107 if !values.is_empty() {
108 return Ok(Value::tensor(values.clone(), vec![values.len()]));
109 }
110 }
111
112 Ok(Value::json(result))
113 }
114}
115
/// A plan-executing worker with its own event bus, cache, and filter library.
pub struct Worker {
    /// Identifier sent in the registration message and compared against
    /// `RemoteTarget::WorkerId` in `matches_target`.
    pub id: WorkerId,
    /// Capabilities sent in the registration message; its `tags` are checked
    /// against `RemoteTarget::Tag` in `matches_target`.
    pub capabilities: Capabilities,
    /// Event bus shared with every execution context created by
    /// `execute_plan`; `subscribe` hands out receivers for it.
    event_bus: Arc<EventBus>,
    /// Cache store passed to the runtime on every plan execution. Defaults to
    /// an in-memory cache and can be replaced via `with_cache`.
    cache: Arc<dyn CacheStore>,
    /// Filters registered on this worker, keyed by node id.
    filters: FilterLibrary,
}
124
125impl Worker {
126 pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
127 Self {
128 id: id.into(),
129 capabilities,
130 event_bus: Arc::new(EventBus::new(256)),
131 cache: Arc::new(MemoryCache::default()),
132 filters: FilterLibrary::new(),
133 }
134 }
135
136 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
138 self.cache = cache;
139 self
140 }
141
142 pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
144 self.filters.register(node_id, filter);
145 }
146
147 pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
149 self.event_bus.subscribe()
150 }
151
152 pub fn registration_message(&self) -> WorkerToCoordinator {
154 WorkerToCoordinator::Register {
155 worker_id: self.id.clone(),
156 capabilities: self.capabilities.clone(),
157 }
158 }
159
160 pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
165 let start = Instant::now();
166
167 for sf in &plan.filters {
169 let filter = Box::new(PickledFilterRunner {
170 pickled_bytes: sf.pickled_filter.clone(),
171 node_id: sf.node_id.clone(),
172 });
173 self.filters.register(&sf.node_id, filter);
174 if let Some(state) = &sf.state {
175 self.filters.set_state(&sf.node_id, state.clone());
176 }
177 }
178
179 let mut ctx = Context::new(
180 self.event_bus.clone(),
181 format!("worker_run_{}", plan.plan_id),
182 );
183
184 if let Some(input_source) = &plan.input {
186 use crate::protocol::InputSource;
187 let input_value = match input_source {
188 InputSource::Inline { value } => value.clone(),
189 InputSource::Reference { data_ref } => {
190 if let Some(store) = &ctx.data_store {
191 store
192 .get(data_ref)
193 .unwrap_or(somatize_core::value::Value::Empty)
194 } else {
195 tracing::warn!("DataRef input but no DataStore configured on worker");
196 somatize_core::value::Value::Empty
197 }
198 }
199 };
200 ctx.set("input", input_value);
201 }
202
203 match execute(&plan.plan, &mut ctx, &self.filters, self.cache.as_ref()) {
204 Ok(()) => {
205 let output = ctx
207 .execution_order
208 .last()
209 .and_then(|id| ctx.get(id))
210 .cloned()
211 .unwrap_or(somatize_core::value::Value::Empty);
212
213 PlanResult::Success {
214 output,
215 duration_ms: start.elapsed().as_millis() as u64,
216 }
217 }
218 Err(e) => PlanResult::Failed {
219 error: e.to_string(),
220 duration_ms: start.elapsed().as_millis() as u64,
221 },
222 }
223 }
224
225 pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
227 match target {
228 somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
229 somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_compiler::ExecutionPlan;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
    use somatize_core::value::Value;

    /// In-process filter that doubles every tensor element and passes any
    /// other value through unchanged.
    struct TestDoubler;

    impl Filter for TestDoubler {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"TestDoubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            if let Value::Tensor { values, shape } = x {
                let doubled = values.iter().map(|v| v * 2.0).collect::<Vec<f64>>();
                return Ok(Value::tensor(doubled, shape.clone()));
            }
            Ok(x.clone())
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "TestDoubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Worker fixture: 4 cores, 8 GB RAM, no GPUs, tagged `cpu` and `test`.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["cpu".into(), "test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// Plan step that executes the single node `node_id`.
    fn execute_node(node_id: &'static str) -> ExecutionPlan {
        ExecutionPlan::Execute {
            node_id: node_id.into(),
        }
    }

    /// Serialized plan with no shipped filters, empty metadata, and an
    /// optional inline input.
    fn build_plan(
        plan_id: &'static str,
        plan: ExecutionPlan,
        input: Option<Value>,
    ) -> SerializedPlan {
        SerializedPlan {
            plan_id: plan_id.into(),
            plan,
            input: input.map(|value| crate::protocol::InputSource::Inline { value }),
            filters: vec![],
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn worker_registration() {
        let worker = make_worker();
        match worker.registration_message() {
            WorkerToCoordinator::Register {
                worker_id,
                capabilities,
            } => {
                assert_eq!(worker_id, "test_worker");
                assert_eq!(capabilities.cpu_cores, 4);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn worker_executes_plan_successfully() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));

        let plan = build_plan(
            "p_001",
            execute_node("doubler"),
            Some(Value::tensor(vec![1.0, 2.0, 3.0], vec![3])),
        );

        match worker.execute_plan(&plan) {
            PlanResult::Success {
                output,
                duration_ms,
            } => {
                let (data, _) = output.as_tensor().unwrap();
                assert_eq!(data, &[2.0, 4.0, 6.0]);
                assert!(duration_ms < 1000);
            }
            result => panic!("expected success, got: {result:?}"),
        }
    }

    #[test]
    fn worker_handles_missing_filter() {
        let mut worker = make_worker();
        let plan = build_plan("p_002", execute_node("nonexistent"), None);

        let result = worker.execute_plan(&plan);
        assert!(matches!(result, PlanResult::Failed { .. }));
    }

    #[test]
    fn worker_matches_target_by_id() {
        use somatize_core::filter::RemoteTarget;

        let worker = make_worker();
        let own_id = RemoteTarget::WorkerId("test_worker".into());
        let foreign_id = RemoteTarget::WorkerId("other".into());

        assert!(worker.matches_target(&own_id));
        assert!(!worker.matches_target(&foreign_id));
    }

    #[test]
    fn worker_matches_target_by_tag() {
        use somatize_core::filter::RemoteTarget;

        let worker = make_worker();
        assert!(worker.matches_target(&RemoteTarget::Tag("cpu".into())));
        assert!(worker.matches_target(&RemoteTarget::Tag("test".into())));
        assert!(!worker.matches_target(&RemoteTarget::Tag("gpu".into())));
    }

    #[test]
    fn worker_executes_sequence() {
        let mut worker = make_worker();
        worker.register_filter("d1", Box::new(TestDoubler));
        worker.register_filter("d2", Box::new(TestDoubler));

        let plan = build_plan(
            "p_003",
            ExecutionPlan::Sequence(vec![execute_node("d1"), execute_node("d2")]),
            Some(Value::tensor(vec![5.0], vec![1])),
        );

        match worker.execute_plan(&plan) {
            PlanResult::Success { output, .. } => {
                let (data, _) = output.as_tensor().unwrap();
                // 5.0 doubled twice.
                assert_eq!(data, &[20.0]);
            }
            _ => panic!("expected success"),
        }
    }

    #[test]
    fn worker_emits_events() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let mut rx = worker.subscribe();

        let plan = build_plan(
            "p_004",
            execute_node("doubler"),
            Some(Value::tensor(vec![1.0], vec![1])),
        );

        worker.execute_plan(&plan);

        // Drain everything currently queued; stop at the first receive error.
        let events: Vec<Event> = std::iter::from_fn(|| rx.try_recv().ok()).collect();

        let saw_started = events
            .iter()
            .any(|e| matches!(e, Event::NodeStarted { .. }));
        assert!(saw_started);

        let saw_completed = events
            .iter()
            .any(|e| matches!(e, Event::NodeCompleted { .. }));
        assert!(saw_completed);
    }
}