1use crate::protocol::*;
4use somatize_core::cache::{CacheKey, CacheStore};
5use somatize_core::error::Result as SomaResult;
6use somatize_core::event::Event;
7use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
8use somatize_core::store::{DataStore, LocalDataStore};
9use somatize_core::value::Value;
10use somatize_runtime::{Context, EventBus, FilterLibrary, MemoryCache, execute};
11use std::sync::Arc;
12use std::time::Instant;
13
/// A [`Filter`] backed by a cloudpickled Python object that is executed in a
/// freshly spawned Python subprocess on every `fit`/`forward` call.
///
/// The pickled object and the JSON-encoded input are streamed to the child
/// over stdin, and the JSON the child prints on stdout becomes the result.
pub(crate) struct PickledFilterRunner {
    /// Raw cloudpickle payload of the Python filter object.
    pub(crate) pickled_bytes: Vec<u8>,
    /// Plan node id this filter implements; used for metadata, logs and errors.
    pub(crate) node_id: String,
    /// Interpreter to spawn (the system `python3` or a venv binary).
    pub(crate) python_path: String,
    /// Packages to `pip install` when the child fails with `ModuleNotFoundError`.
    pub(crate) requirements: Vec<String>,
    /// Whether `meta()` reports this filter as trainable rather than stateless.
    pub(crate) trainable: bool,
}
28
29impl Filter for PickledFilterRunner {
30 fn config_hash(&self) -> CacheKey {
31 CacheKey::from_parts(&[&self.pickled_bytes])
32 }
33
34 fn fit(&self, x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
35 self.run_python("fit", x)
36 }
37
38 fn forward(&self, x: &Value, state: &Value) -> SomaResult<Value> {
39 let input = if matches!(state, Value::Empty) {
40 x.clone()
41 } else {
42 Value::json(serde_json::json!({
43 "x": serde_json::to_value(x).unwrap_or_default(),
44 "state": serde_json::to_value(state).unwrap_or_default(),
45 }))
46 };
47 self.run_python("forward", &input)
48 }
49
50 fn meta(&self) -> FilterMeta {
51 FilterMeta {
52 name: self.node_id.clone(),
53 kind: if self.trainable {
54 FilterKind::Trainable
55 } else {
56 FilterKind::Stateless
57 },
58 cacheable: true,
59 differentiable: false,
60 stream_mode: StreamMode::FixedState,
61 distribution: somatize_core::filter::Distribution::Local,
62 input_schema: None,
63 output_schema: None,
64 }
65 }
66}
67
68impl PickledFilterRunner {
69 fn run_python(&self, method: &str, input: &Value) -> SomaResult<Value> {
70 self.run_python_with_retry(method, input, true)
71 }
72
73 fn run_python_with_retry(
74 &self,
75 method: &str,
76 input: &Value,
77 allow_retry: bool,
78 ) -> SomaResult<Value> {
79 use base64::engine::{Engine, general_purpose::STANDARD};
80 use std::io::Write;
81
82 let input_json = serde_json::to_string(input)
83 .map_err(|e| somatize_core::error::SomaError::Other(format!("serialize input: {e}")))?;
84 let pickled_b64 = STANDARD.encode(&self.pickled_bytes);
85
86 let script = format!(
87 r#"
88import json, sys, base64, cloudpickle
89
90def unwrap_value(v):
91 """Convert Soma Value JSON to native Python types."""
92 if isinstance(v, dict) and "type" in v and "data" in v:
93 t = v["type"]
94 d = v["data"]
95 if t == "Tensor":
96 return d.get("values", [])
97 if t == "Json":
98 return d
99 if t == "Empty":
100 return {{}}
101 if t == "Bytes":
102 return bytes(d)
103 return v
104
105pickled_b64 = sys.stdin.readline().strip()
106input_line = sys.stdin.read()
107
108pickled = base64.b64decode(pickled_b64)
109obj = cloudpickle.loads(pickled)
110raw = json.loads(input_line)
111input_data = unwrap_value(raw)
112
113if isinstance(input_data, dict) and "x" in input_data and "state" in input_data:
114 x = unwrap_value(input_data["x"])
115 state = unwrap_value(input_data["state"])
116 result = obj.{method}(x, state)
117else:
118 result = obj.{method}(input_data, {{}})
119
120print(json.dumps(result))
121"#,
122 );
123
124 let mut child = std::process::Command::new(&self.python_path)
125 .args(["-c", &script])
126 .stdin(std::process::Stdio::piped())
127 .stdout(std::process::Stdio::piped())
128 .stderr(std::process::Stdio::piped())
129 .spawn()
130 .map_err(|e| {
131 somatize_core::error::SomaError::Other(format!("python spawn failed: {e}"))
132 })?;
133
134 if let Some(mut stdin) = child.stdin.take() {
135 let _ = writeln!(stdin, "{pickled_b64}");
136 let _ = write!(stdin, "{input_json}");
137 }
138
139 let output = child.wait_with_output().map_err(|e| {
140 somatize_core::error::SomaError::Other(format!("python exec failed: {e}"))
141 })?;
142
143 if !output.status.success() {
144 let stderr = String::from_utf8_lossy(&output.stderr);
145
146 if allow_retry && stderr.contains("ModuleNotFoundError") {
148 let missing = Self::parse_missing_module(&stderr);
149 let mut to_install: Vec<String> = self.requirements.clone();
151 if let Some(ref m) = missing
152 && !to_install.iter().any(|r| r == m)
153 {
154 to_install.push(m.clone());
155 }
156 if !to_install.is_empty() {
157 let names = to_install.join(", ");
158 tracing::warn!(
159 "Missing module for filter '{}', installing: {names}",
160 self.node_id
161 );
162 let mut args = vec!["-m", "pip", "install", "--quiet"];
163 let refs: Vec<&str> = to_install.iter().map(|s| s.as_str()).collect();
164 args.extend(refs);
165 let install = std::process::Command::new(&self.python_path)
166 .args(&args)
167 .output();
168 if let Ok(res) = install
169 && res.status.success()
170 {
171 tracing::info!("Installed [{names}], retrying...");
172 return self.run_python_with_retry(method, input, false);
173 }
174 }
175 }
176
177 return Err(somatize_core::error::SomaError::Execution {
178 node_id: self.node_id.clone(),
179 message: format!("Python error: {stderr}"),
180 });
181 }
182
183 let stdout = String::from_utf8_lossy(&output.stdout);
184 let result: serde_json::Value = serde_json::from_str(stdout.trim()).map_err(|e| {
185 somatize_core::error::SomaError::Other(format!(
186 "parse python output: {e}\nstdout: {stdout}"
187 ))
188 })?;
189
190 if let Some(arr) = result.as_array() {
191 let values: Vec<f64> = arr.iter().filter_map(|v| v.as_f64()).collect();
192 if !values.is_empty() {
193 return Ok(Value::tensor(values.clone(), vec![values.len()]));
194 }
195 }
196
197 Ok(Value::json(result))
198 }
199
200 fn parse_missing_module(stderr: &str) -> Option<String> {
202 for line in stderr.lines().rev() {
203 if line.contains("ModuleNotFoundError") {
204 if let Some(start) = line.find('\'') {
206 let rest = &line[start + 1..];
207 if let Some(end) = rest.find('\'') {
208 return Some(rest[..end].split('.').next()?.to_string());
209 }
210 }
211 }
212 }
213 None
214 }
215}
216
217pub struct Worker {
219 pub id: WorkerId,
220 pub capabilities: Capabilities,
221 event_bus: Arc<EventBus>,
222 cache: Arc<dyn CacheStore>,
223 filters: FilterLibrary,
224 data_store: Option<Arc<dyn DataStore>>,
226 temp_store: Arc<LocalDataStore>,
228 env_manager: crate::env_manager::EnvManager,
230}
231
232impl Worker {
233 pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
234 let worker_id: String = id.into();
235 let temp_path = std::env::temp_dir().join(format!("soma-uploads-{worker_id}"));
236 let temp_store = LocalDataStore::new(temp_path);
237 let env_path = std::env::temp_dir().join(format!("soma-envs-{worker_id}"));
238 Self {
239 id: worker_id,
240 capabilities,
241 event_bus: Arc::new(EventBus::new(256)),
242 cache: Arc::new(MemoryCache::default()),
243 filters: FilterLibrary::new(),
244 data_store: None,
245 temp_store: Arc::new(temp_store),
246 env_manager: crate::env_manager::EnvManager::new(
247 env_path,
248 crate::env_manager::EnvType::Venv,
249 ),
250 }
251 }
252
253 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
255 self.cache = cache;
256 self
257 }
258
259 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
261 self.data_store = Some(store);
262 self
263 }
264
265 pub fn with_temp_dir(mut self, path: std::path::PathBuf) -> Self {
267 self.temp_store = Arc::new(LocalDataStore::new(path));
268 self
269 }
270
271 pub fn temp_store(&self) -> &Arc<LocalDataStore> {
273 &self.temp_store
274 }
275
276 pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
278 self.filters.register(node_id, filter);
279 }
280
281 pub fn get_filter(&self, node_id: &str) -> Option<Arc<dyn Filter>> {
283 self.filters.get(node_id)
284 }
285
286 pub fn get_filter_state(&self, node_id: &str) -> Value {
288 self.filters
289 .get_state(node_id)
290 .cloned()
291 .unwrap_or(Value::Empty)
292 }
293
294 pub fn set_filter_state(&mut self, node_id: &str, state: Value) {
296 self.filters.set_state(node_id, state);
297 }
298
299 pub fn wrap_output(&self, output: Value) -> OutputDelivery {
301 let size = serde_json::to_vec(&output).map(|v| v.len()).unwrap_or(0);
302 if size >= somatize_core::store::INLINE_THRESHOLD_BYTES {
303 let key = somatize_core::cache::CacheKey::hash_data(
304 &serde_json::to_vec(&output).unwrap_or_default(),
305 );
306 if let Ok(data_ref) = self.temp_store.put(&key, &output) {
307 return OutputDelivery::Reference { data_ref };
308 }
309 }
310 OutputDelivery::Inline { value: output }
311 }
312
313 pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
315 self.event_bus.subscribe()
316 }
317
318 pub fn registration_message(&self) -> WorkerToCoordinator {
320 WorkerToCoordinator::Register {
321 worker_id: self.id.clone(),
322 capabilities: self.capabilities.clone(),
323 }
324 }
325
326 pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
336 let start = Instant::now();
337
338 let all_reqs: Vec<String> = plan
340 .filters
341 .iter()
342 .flat_map(|sf| sf.requirements.iter().cloned())
343 .collect::<std::collections::HashSet<_>>()
344 .into_iter()
345 .collect();
346
347 let python_path = if all_reqs.is_empty() {
349 "python3".to_string()
350 } else {
351 let reqs_str = all_reqs.join("\n");
352 match self.env_manager.ensure_env(&plan.plan_id, &reqs_str) {
353 Ok(path) => {
354 tracing::info!("Using venv for plan {}: {:?}", plan.plan_id, path);
355 path.to_string_lossy().to_string()
356 }
357 Err(e) => {
358 tracing::warn!("Failed to create venv, falling back to system python: {e}");
359 "python3".to_string()
360 }
361 }
362 };
363
364 #[cfg(feature = "embedded-python")]
366 let site_packages = if python_path != "python3" {
367 let venv_dir = std::path::Path::new(&python_path)
369 .parent()
370 .and_then(|bin| bin.parent());
371 venv_dir.and_then(|d| {
372 std::fs::read_dir(d.join("lib"))
373 .ok()?
374 .filter_map(|e| e.ok())
375 .find(|e| e.file_name().to_string_lossy().starts_with("python"))
376 .map(|e| e.path().join("site-packages").to_string_lossy().to_string())
377 })
378 } else {
379 None
380 };
381
382 for sf in &plan.filters {
384 let filter: Box<dyn Filter> = {
385 #[cfg(feature = "embedded-python")]
386 {
387 match crate::py_filter::EmbeddedPyFilter::new(
388 &sf.pickled_filter,
389 sf.node_id.clone(),
390 sf.trainable,
391 site_packages.as_deref(),
392 ) {
393 Ok(embedded) => {
394 tracing::info!("Using embedded PyO3 filter for '{}'", sf.node_id);
395 Box::new(embedded)
396 }
397 Err(e) => {
398 tracing::warn!(
399 "PyO3 embed failed for '{}': {e}, falling back to subprocess",
400 sf.node_id
401 );
402 Box::new(PickledFilterRunner {
403 pickled_bytes: sf.pickled_filter.clone(),
404 node_id: sf.node_id.clone(),
405 python_path: python_path.clone(),
406 requirements: sf.requirements.clone(),
407 trainable: sf.trainable,
408 })
409 }
410 }
411 }
412 #[cfg(not(feature = "embedded-python"))]
413 {
414 Box::new(PickledFilterRunner {
415 pickled_bytes: sf.pickled_filter.clone(),
416 node_id: sf.node_id.clone(),
417 python_path: python_path.clone(),
418 requirements: sf.requirements.clone(),
419 trainable: sf.trainable,
420 })
421 }
422 };
423 self.filters.register(&sf.node_id, filter);
424 if let Some(state) = &sf.state {
425 self.filters.set_state(&sf.node_id, state.clone());
426 }
427 }
428
429 let input_value = plan.input.as_ref().map(|src| match src {
431 InputSource::Inline { value } => value.clone(),
432 InputSource::Reference { data_ref } => {
433 if let Some(store) = &self.data_store
435 && let Ok(val) = store.get(data_ref)
436 {
437 return val;
438 }
439 self.temp_store.get(data_ref).unwrap_or_else(|e| {
440 tracing::warn!("Failed to resolve DataRef: {e}");
441 Value::Empty
442 })
443 }
444 });
445
446 if let Some(InputSource::Reference { data_ref }) = &plan.input
449 && let Some(store) = self.data_store.clone()
450 && let Ok(meta) = store.meta(data_ref)
451 && meta.total_rows > 1024
452 {
453 return self.execute_streamed_from_store(plan, &store, data_ref, &meta, start);
454 }
455
456 match &plan.mode {
457 ExecutionMode::Fit { y } => self.execute_fit(plan, input_value, y.as_ref(), start),
458 ExecutionMode::Forward => self.execute_forward(plan, input_value, start),
459 }
460 }
461
462 fn execute_forward(
464 &mut self,
465 plan: &SerializedPlan,
466 input: Option<Value>,
467 start: Instant,
468 ) -> PlanResult {
469 let mut ctx = Context::new(
470 self.event_bus.clone(),
471 format!("worker_run_{}", plan.plan_id),
472 );
473
474 if let Some(val) = input {
475 ctx.set("input", val.clone());
476 if let somatize_compiler::ExecutionPlan::Execute { node_id } = &plan.plan {
478 ctx.set(format!("__input_{node_id}"), val);
479 }
480 }
481
482 match execute(&plan.plan, &mut ctx, &self.filters, self.cache.as_ref()) {
483 Ok(()) => {
484 let output = ctx
485 .execution_order
486 .last()
487 .and_then(|id| ctx.get(id))
488 .cloned()
489 .unwrap_or(Value::Empty);
490
491 PlanResult::Success {
492 output: self.wrap_output(output),
493 duration_ms: start.elapsed().as_millis() as u64,
494 states: std::collections::HashMap::new(),
495 }
496 }
497 Err(e) => PlanResult::Failed {
498 error: e.to_string(),
499 duration_ms: start.elapsed().as_millis() as u64,
500 },
501 }
502 }
503
504 fn execute_fit(
506 &mut self,
507 plan: &SerializedPlan,
508 input: Option<Value>,
509 y: Option<&Value>,
510 start: Instant,
511 ) -> PlanResult {
512 let run_id = format!("worker_fit_{}", plan.plan_id);
513 let x = input.unwrap_or(Value::Empty);
514
515 let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
517 let mut outputs: std::collections::HashMap<String, Value> =
518 std::collections::HashMap::new();
519 let mut trained_states: std::collections::HashMap<String, Value> =
520 std::collections::HashMap::new();
521
522 for node_id in &node_ids {
523 let filter = match self.filters.get(node_id) {
524 Some(f) => f,
525 None => {
526 return PlanResult::Failed {
527 error: format!("filter not found: {node_id}"),
528 duration_ms: start.elapsed().as_millis() as u64,
529 };
530 }
531 };
532
533 let meta = filter.meta();
534
535 self.event_bus.emit(Event::NodeStarted {
536 run_id: run_id.clone(),
537 node_id: node_id.to_string(),
538 kind: meta.kind,
539 });
540
541 let node_start = Instant::now();
542
543 let node_input = outputs
545 .values()
546 .last()
547 .cloned()
548 .unwrap_or_else(|| x.clone());
549
550 let state = if meta.kind == FilterKind::Trainable {
552 match filter.fit(&node_input, y) {
553 Ok(s) => {
554 self.filters.set_state(node_id, s.clone());
555 trained_states.insert(node_id.clone(), s.clone());
556 s
557 }
558 Err(e) => {
559 return PlanResult::Failed {
560 error: format!("fit({node_id}): {e}"),
561 duration_ms: start.elapsed().as_millis() as u64,
562 };
563 }
564 }
565 } else {
566 self.filters
567 .get_state(node_id)
568 .cloned()
569 .unwrap_or(Value::Empty)
570 };
571
572 match filter.forward(&node_input, &state) {
574 Ok(output) => {
575 self.event_bus.emit(Event::NodeCompleted {
576 run_id: run_id.clone(),
577 node_id: node_id.to_string(),
578 duration: node_start.elapsed(),
579 output_summary: format!("{output}"),
580 });
581 outputs.insert(node_id.clone(), output);
582 }
583 Err(e) => {
584 return PlanResult::Failed {
585 error: format!("forward({node_id}): {e}"),
586 duration_ms: start.elapsed().as_millis() as u64,
587 };
588 }
589 }
590 }
591
592 let output = outputs.values().last().cloned().unwrap_or(Value::Empty);
593
594 PlanResult::Success {
595 output: self.wrap_output(output),
596 duration_ms: start.elapsed().as_millis() as u64,
597 states: trained_states,
598 }
599 }
600
601 fn execute_streamed_from_store(
604 &mut self,
605 plan: &SerializedPlan,
606 store: &Arc<dyn DataStore>,
607 data_ref: &somatize_core::store::DataRef,
608 meta: &somatize_core::store::StoreMeta,
609 start: Instant,
610 ) -> PlanResult {
611 use somatize_runtime::stream::{FittedFilter, StreamExecutor};
612
613 let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
614 let fitted: Vec<FittedFilter> = node_ids
615 .iter()
616 .filter_map(|id| {
617 let filter = self.filters.get(id)?;
618 let state = self.filters.get_state(id).cloned().unwrap_or(Value::Empty);
619 Some(FittedFilter {
620 name: id.clone(),
621 filter,
622 state,
623 })
624 })
625 .collect();
626
627 let mut executor = StreamExecutor::new(fitted);
628 let chunk_size = 1024;
629 let run_id = format!("worker_stream_{}", plan.plan_id);
630
631 self.event_bus.emit(Event::RunStarted {
632 run_id: run_id.clone(),
633 plan_summary: somatize_core::event::PlanSummary {
634 total_nodes: node_ids.len(),
635 cached_nodes: 0,
636 parallel_branches: 0,
637 },
638 });
639
640 let mut last_output = Value::Empty;
641 let total = meta.total_rows;
642 let mut chunk_idx = 0;
643
644 for row_start in (0..total).step_by(chunk_size) {
645 let len = chunk_size.min(total - row_start);
646 let chunk = match store.get_rows(data_ref, row_start, len) {
647 Ok(c) => c,
648 Err(e) => {
649 return PlanResult::Failed {
650 error: format!("get_rows({row_start}..{}): {e}", row_start + len),
651 duration_ms: start.elapsed().as_millis() as u64,
652 };
653 }
654 };
655
656 match executor.process_chunk(chunk) {
657 Ok(Some(output)) => last_output = output,
658 Ok(None) => {} Err(e) => {
660 return PlanResult::Failed {
661 error: format!("stream chunk {chunk_idx}: {e}"),
662 duration_ms: start.elapsed().as_millis() as u64,
663 };
664 }
665 }
666 chunk_idx += 1;
667 }
668
669 match executor.flush() {
671 Ok(Some(output)) => last_output = output,
672 Ok(None) => {}
673 Err(e) => {
674 return PlanResult::Failed {
675 error: format!("stream flush: {e}"),
676 duration_ms: start.elapsed().as_millis() as u64,
677 };
678 }
679 }
680
681 tracing::info!(
682 "Streamed {chunk_idx} chunks ({total} rows) in {}ms",
683 start.elapsed().as_millis()
684 );
685
686 PlanResult::Success {
687 output: self.wrap_output(last_output),
688 duration_ms: start.elapsed().as_millis() as u64,
689 states: std::collections::HashMap::new(),
690 }
691 }
692
693 pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
695 match target {
696 somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
697 somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
698 }
699 }
700}
701
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_compiler::ExecutionPlan;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
    use somatize_core::value::Value;

    /// Stateless filter that doubles every tensor element and returns any
    /// non-tensor input unchanged.
    struct TestDoubler;

    impl Filter for TestDoubler {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"TestDoubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            let Value::Tensor { values, shape } = x else {
                return Ok(x.clone());
            };
            let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
            Ok(Value::tensor(doubled, shape.clone()))
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "TestDoubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Worker `test_worker` with 4 cores, 8 GB RAM, no GPUs and tags `cpu`/`test`.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["cpu".into(), "test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// Default-mode plan with no shipped filters and an optional inline input.
    fn forward_plan(plan_id: &str, plan: ExecutionPlan, input: Option<Value>) -> SerializedPlan {
        SerializedPlan {
            plan_id: plan_id.into(),
            plan,
            input: input.map(|value| crate::protocol::InputSource::Inline { value }),
            filters: vec![],
            mode: ExecutionMode::default(),
            metadata: serde_json::json!({}),
        }
    }

    /// Single-node plan executing `node_id`.
    fn execute_node(node_id: &str) -> ExecutionPlan {
        ExecutionPlan::Execute {
            node_id: node_id.into(),
        }
    }

    /// Extracts the value of an inline delivery; panics on a reference.
    fn expect_inline(output: OutputDelivery) -> Value {
        match output {
            OutputDelivery::Inline { value } => value,
            _ => panic!("expected inline output"),
        }
    }

    #[test]
    fn worker_registration() {
        match make_worker().registration_message() {
            WorkerToCoordinator::Register {
                worker_id,
                capabilities,
            } => {
                assert_eq!(worker_id, "test_worker");
                assert_eq!(capabilities.cpu_cores, 4);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn worker_executes_plan_successfully() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));

        let plan = forward_plan(
            "p_001",
            execute_node("doubler"),
            Some(Value::tensor(vec![1.0, 2.0, 3.0], vec![3])),
        );

        match worker.execute_plan(&plan) {
            PlanResult::Success {
                output,
                duration_ms,
                ..
            } => {
                let value = expect_inline(output);
                let (data, _) = value.as_tensor().unwrap();
                assert_eq!(data, &[2.0, 4.0, 6.0]);
                assert!(duration_ms < 1000);
            }
            other => panic!("expected success, got: {other:?}"),
        }
    }

    #[test]
    fn worker_handles_missing_filter() {
        let mut worker = make_worker();
        let plan = forward_plan("p_002", execute_node("nonexistent"), None);
        assert!(matches!(
            worker.execute_plan(&plan),
            PlanResult::Failed { .. }
        ));
    }

    #[test]
    fn worker_matches_target_by_id() {
        use somatize_core::filter::RemoteTarget;

        let worker = make_worker();
        assert!(worker.matches_target(&RemoteTarget::WorkerId("test_worker".into())));
        assert!(!worker.matches_target(&RemoteTarget::WorkerId("other".into())));
    }

    #[test]
    fn worker_matches_target_by_tag() {
        use somatize_core::filter::RemoteTarget;

        let worker = make_worker();
        for (tag, expected) in [("cpu", true), ("test", true), ("gpu", false)] {
            assert_eq!(
                worker.matches_target(&RemoteTarget::Tag(tag.into())),
                expected,
                "tag {tag}"
            );
        }
    }

    #[test]
    fn worker_executes_sequence() {
        let mut worker = make_worker();
        for id in ["d1", "d2"] {
            worker.register_filter(id, Box::new(TestDoubler));
        }

        let plan = forward_plan(
            "p_003",
            ExecutionPlan::Sequence(vec![execute_node("d1"), execute_node("d2")]),
            Some(Value::tensor(vec![5.0], vec![1])),
        );

        let PlanResult::Success { output, .. } = worker.execute_plan(&plan) else {
            panic!("expected success");
        };
        let value = expect_inline(output);
        let (data, _) = value.as_tensor().unwrap();
        // 5 doubled by each of the two stages.
        assert_eq!(data, &[20.0]);
    }

    #[test]
    fn worker_emits_events() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let mut rx = worker.subscribe();

        let plan = forward_plan(
            "p_004",
            execute_node("doubler"),
            Some(Value::tensor(vec![1.0], vec![1])),
        );
        let _ = worker.execute_plan(&plan);

        // Drain everything already queued; stop at the first non-Ok receive.
        let events: Vec<Event> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        assert!(events.iter().any(|e| matches!(e, Event::NodeStarted { .. })));
        assert!(events.iter().any(|e| matches!(e, Event::NodeCompleted { .. })));
    }
}