// somatize_worker/worker.rs

//! Worker — receives and executes plans from a coordinator.

3use crate::protocol::*;
4use somatize_core::cache::{CacheKey, CacheStore};
5use somatize_core::error::Result as SomaResult;
6use somatize_core::event::Event;
7use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
8use somatize_core::store::{DataStore, LocalDataStore};
9use somatize_core::value::Value;
10use somatize_runtime::{Context, EventBus, FilterLibrary, MemoryCache, execute};
11use std::sync::Arc;
12use std::time::Instant;
13
/// A Python filter shipped to this worker as cloudpickle bytes.
///
/// Every `fit`/`forward` call launches `python_path` as a subprocess, which
/// unpickles the object and invokes the corresponding method on it.
pub(crate) struct PickledFilterRunner {
    /// Raw output of `cloudpickle.dumps()` for the client-side filter object.
    pub(crate) pickled_bytes: Vec<u8>,
    /// Plan node this runner is registered under; also quoted in errors.
    pub(crate) node_id: String,
    /// Interpreter that executes the filter (a venv interpreter or `python3`).
    pub(crate) python_path: String,
    /// Pip requirement specs installed when the subprocess fails with a
    /// `ModuleNotFoundError`, before a single retry.
    pub(crate) requirements: Vec<String>,
}
26
27impl Filter for PickledFilterRunner {
28    fn config_hash(&self) -> CacheKey {
29        CacheKey::from_parts(&[&self.pickled_bytes])
30    }
31
32    fn fit(&self, x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
33        self.run_python("fit", x)
34    }
35
36    fn forward(&self, x: &Value, state: &Value) -> SomaResult<Value> {
37        let input = if matches!(state, Value::Empty) {
38            x.clone()
39        } else {
40            Value::json(serde_json::json!({
41                "x": serde_json::to_value(x).unwrap_or_default(),
42                "state": serde_json::to_value(state).unwrap_or_default(),
43            }))
44        };
45        self.run_python("forward", &input)
46    }
47
48    fn meta(&self) -> FilterMeta {
49        FilterMeta {
50            name: self.node_id.clone(),
51            kind: FilterKind::Stateless,
52            cacheable: true,
53            differentiable: false,
54            stream_mode: StreamMode::FixedState,
55            distribution: somatize_core::filter::Distribution::Local,
56            input_schema: None,
57            output_schema: None,
58        }
59    }
60}
61
62impl PickledFilterRunner {
63    fn run_python(&self, method: &str, input: &Value) -> SomaResult<Value> {
64        self.run_python_with_retry(method, input, true)
65    }
66
67    fn run_python_with_retry(
68        &self,
69        method: &str,
70        input: &Value,
71        allow_retry: bool,
72    ) -> SomaResult<Value> {
73        use base64::engine::{Engine, general_purpose::STANDARD};
74        use std::io::Write;
75
76        let input_json = serde_json::to_string(input)
77            .map_err(|e| somatize_core::error::SomaError::Other(format!("serialize input: {e}")))?;
78        let pickled_b64 = STANDARD.encode(&self.pickled_bytes);
79
80        let script = format!(
81            r#"
82import json, sys, base64, cloudpickle
83
84def unwrap_value(v):
85    """Convert Soma Value JSON to native Python types."""
86    if isinstance(v, dict) and "type" in v and "data" in v:
87        t = v["type"]
88        d = v["data"]
89        if t == "Tensor":
90            return d.get("values", [])
91        if t == "Json":
92            return d
93        if t == "Empty":
94            return {{}}
95        if t == "Bytes":
96            return bytes(d)
97    return v
98
99pickled_b64 = sys.stdin.readline().strip()
100input_line = sys.stdin.read()
101
102pickled = base64.b64decode(pickled_b64)
103obj = cloudpickle.loads(pickled)
104raw = json.loads(input_line)
105input_data = unwrap_value(raw)
106
107if isinstance(input_data, dict) and "x" in input_data and "state" in input_data:
108    x = unwrap_value(input_data["x"])
109    state = unwrap_value(input_data["state"])
110    result = obj.{method}(x, state)
111else:
112    result = obj.{method}(input_data, {{}})
113
114print(json.dumps(result))
115"#,
116        );
117
118        let mut child = std::process::Command::new(&self.python_path)
119            .args(["-c", &script])
120            .stdin(std::process::Stdio::piped())
121            .stdout(std::process::Stdio::piped())
122            .stderr(std::process::Stdio::piped())
123            .spawn()
124            .map_err(|e| {
125                somatize_core::error::SomaError::Other(format!("python spawn failed: {e}"))
126            })?;
127
128        if let Some(mut stdin) = child.stdin.take() {
129            let _ = writeln!(stdin, "{pickled_b64}");
130            let _ = write!(stdin, "{input_json}");
131        }
132
133        let output = child.wait_with_output().map_err(|e| {
134            somatize_core::error::SomaError::Other(format!("python exec failed: {e}"))
135        })?;
136
137        if !output.status.success() {
138            let stderr = String::from_utf8_lossy(&output.stderr);
139
140            // Retry on ModuleNotFoundError: install known requirements + missing package
141            if allow_retry && stderr.contains("ModuleNotFoundError") {
142                let missing = Self::parse_missing_module(&stderr);
143                // Collect packages to install: known requirements + the missing one
144                let mut to_install: Vec<String> = self.requirements.clone();
145                if let Some(ref m) = missing
146                    && !to_install.iter().any(|r| r == m)
147                {
148                    to_install.push(m.clone());
149                }
150                if !to_install.is_empty() {
151                    let names = to_install.join(", ");
152                    tracing::warn!(
153                        "Missing module for filter '{}', installing: {names}",
154                        self.node_id
155                    );
156                    let mut args = vec!["-m", "pip", "install", "--quiet"];
157                    let refs: Vec<&str> = to_install.iter().map(|s| s.as_str()).collect();
158                    args.extend(refs);
159                    let install = std::process::Command::new(&self.python_path)
160                        .args(&args)
161                        .output();
162                    if let Ok(res) = install
163                        && res.status.success()
164                    {
165                        tracing::info!("Installed [{names}], retrying...");
166                        return self.run_python_with_retry(method, input, false);
167                    }
168                }
169            }
170
171            return Err(somatize_core::error::SomaError::Execution {
172                node_id: self.node_id.clone(),
173                message: format!("Python error: {stderr}"),
174            });
175        }
176
177        let stdout = String::from_utf8_lossy(&output.stdout);
178        let result: serde_json::Value = serde_json::from_str(stdout.trim()).map_err(|e| {
179            somatize_core::error::SomaError::Other(format!(
180                "parse python output: {e}\nstdout: {stdout}"
181            ))
182        })?;
183
184        if let Some(arr) = result.as_array() {
185            let values: Vec<f64> = arr.iter().filter_map(|v| v.as_f64()).collect();
186            if !values.is_empty() {
187                return Ok(Value::tensor(values.clone(), vec![values.len()]));
188            }
189        }
190
191        Ok(Value::json(result))
192    }
193
194    /// Parse "ModuleNotFoundError: No module named 'xxx'" from stderr.
195    fn parse_missing_module(stderr: &str) -> Option<String> {
196        for line in stderr.lines().rev() {
197            if line.contains("ModuleNotFoundError") {
198                // "ModuleNotFoundError: No module named 'xxx'"
199                if let Some(start) = line.find('\'') {
200                    let rest = &line[start + 1..];
201                    if let Some(end) = rest.find('\'') {
202                        return Some(rest[..end].split('.').next()?.to_string());
203                    }
204                }
205            }
206        }
207        None
208    }
209}
210
211/// Worker state: manages execution of plans received from a coordinator.
212pub struct Worker {
213    pub id: WorkerId,
214    pub capabilities: Capabilities,
215    event_bus: Arc<EventBus>,
216    cache: Arc<dyn CacheStore>,
217    filters: FilterLibrary,
218    /// Optional persistent DataStore (S3, Zarr, etc.) — configured by user.
219    data_store: Option<Arc<dyn DataStore>>,
220    /// Temporary local store for HTTP bulk uploads — auto-created, auto-cleaned.
221    temp_store: Arc<LocalDataStore>,
222    /// Environment manager for creating venvs with filter dependencies.
223    env_manager: crate::env_manager::EnvManager,
224}
225
226impl Worker {
227    pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
228        let worker_id: String = id.into();
229        let temp_path = std::env::temp_dir().join(format!("soma-uploads-{worker_id}"));
230        let temp_store = LocalDataStore::new(temp_path);
231        let env_path = std::env::temp_dir().join(format!("soma-envs-{worker_id}"));
232        Self {
233            id: worker_id,
234            capabilities,
235            event_bus: Arc::new(EventBus::new(256)),
236            cache: Arc::new(MemoryCache::default()),
237            filters: FilterLibrary::new(),
238            data_store: None,
239            temp_store: Arc::new(temp_store),
240            env_manager: crate::env_manager::EnvManager::new(
241                env_path,
242                crate::env_manager::EnvType::Venv,
243            ),
244        }
245    }
246
247    /// Set a custom cache store (e.g. tiered or shared).
248    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
249        self.cache = cache;
250        self
251    }
252
253    /// Set a persistent DataStore (S3, Zarr, etc.) for large data references.
254    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
255        self.data_store = Some(store);
256        self
257    }
258
259    /// Set a custom temp directory for HTTP bulk uploads.
260    pub fn with_temp_dir(mut self, path: std::path::PathBuf) -> Self {
261        self.temp_store = Arc::new(LocalDataStore::new(path));
262        self
263    }
264
265    /// Get the temp store (for HTTP upload endpoint).
266    pub fn temp_store(&self) -> &Arc<LocalDataStore> {
267        &self.temp_store
268    }
269
270    /// Register a filter that this worker can execute.
271    pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
272        self.filters.register(node_id, filter);
273    }
274
275    /// Get a filter by node_id (for stream executor construction).
276    pub fn get_filter(&self, node_id: &str) -> Option<Arc<dyn Filter>> {
277        self.filters.get(node_id)
278    }
279
280    /// Get trained state for a filter.
281    pub fn get_filter_state(&self, node_id: &str) -> Value {
282        self.filters
283            .get_state(node_id)
284            .cloned()
285            .unwrap_or(Value::Empty)
286    }
287
288    /// Set trained state for a filter.
289    pub fn set_filter_state(&mut self, node_id: &str, state: Value) {
290        self.filters.set_state(node_id, state);
291    }
292
293    /// Subscribe to execution events.
294    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
295        self.event_bus.subscribe()
296    }
297
298    /// Build a registration message.
299    pub fn registration_message(&self) -> WorkerToCoordinator {
300        WorkerToCoordinator::Register {
301            worker_id: self.id.clone(),
302            capabilities: self.capabilities.clone(),
303        }
304    }
305
306    /// Execute a serialized plan.
307    ///
308    /// If the plan contains serialized filter definitions, they are registered
309    /// temporarily for this execution (alongside any pre-registered filters).
310    ///
311    /// In **Fit** mode: fits each filter (topological order), stores trained states,
312    /// then forwards to propagate outputs. Returns states so the client can cache them.
313    ///
314    /// In **Forward** mode: executes the compiled plan directly.
315    pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
316        let start = Instant::now();
317
318        // Collect all requirements from serialized filters
319        let all_reqs: Vec<String> = plan
320            .filters
321            .iter()
322            .flat_map(|sf| sf.requirements.iter().cloned())
323            .collect::<std::collections::HashSet<_>>()
324            .into_iter()
325            .collect();
326
327        // Create/reuse venv if there are pip requirements, otherwise use system python
328        let python_path = if all_reqs.is_empty() {
329            "python3".to_string()
330        } else {
331            let reqs_str = all_reqs.join("\n");
332            match self.env_manager.ensure_env(&plan.plan_id, &reqs_str) {
333                Ok(path) => {
334                    tracing::info!("Using venv for plan {}: {:?}", plan.plan_id, path);
335                    path.to_string_lossy().to_string()
336                }
337                Err(e) => {
338                    tracing::warn!("Failed to create venv, falling back to system python: {e}");
339                    "python3".to_string()
340                }
341            }
342        };
343
344        // Register pickled filters (from remote client via cloudpickle)
345        for sf in &plan.filters {
346            let filter = Box::new(PickledFilterRunner {
347                pickled_bytes: sf.pickled_filter.clone(),
348                node_id: sf.node_id.clone(),
349                python_path: python_path.clone(),
350                requirements: sf.requirements.clone(),
351            });
352            self.filters.register(&sf.node_id, filter);
353            if let Some(state) = &sf.state {
354                self.filters.set_state(&sf.node_id, state.clone());
355            }
356        }
357
358        // Resolve input: inline, DataStore, or temp store (HTTP upload)
359        let input_value = plan.input.as_ref().map(|src| match src {
360            InputSource::Inline { value } => value.clone(),
361            InputSource::Reference { data_ref } => {
362                // Try persistent DataStore first, then temp store (HTTP uploads)
363                if let Some(store) = &self.data_store
364                    && let Ok(val) = store.get(data_ref)
365                {
366                    return val;
367                }
368                self.temp_store.get(data_ref).unwrap_or_else(|e| {
369                    tracing::warn!("Failed to resolve DataRef: {e}");
370                    Value::Empty
371                })
372            }
373        });
374
375        // DataStore-backed streaming: if input is a large DataRef and we have a store,
376        // read chunks via get_rows() and process with StreamExecutor (no full materialization).
377        if let Some(InputSource::Reference { data_ref }) = &plan.input
378            && let Some(store) = self.data_store.clone()
379            && let Ok(meta) = store.meta(data_ref)
380            && meta.total_rows > 1024
381        {
382            return self.execute_streamed_from_store(plan, &store, data_ref, &meta, start);
383        }
384
385        match &plan.mode {
386            ExecutionMode::Fit { y } => self.execute_fit(plan, input_value, y.as_ref(), start),
387            ExecutionMode::Forward => self.execute_forward(plan, input_value, start),
388        }
389    }
390
391    /// Forward mode: run the compiled execution plan.
392    fn execute_forward(
393        &mut self,
394        plan: &SerializedPlan,
395        input: Option<Value>,
396        start: Instant,
397    ) -> PlanResult {
398        let mut ctx = Context::new(
399            self.event_bus.clone(),
400            format!("worker_run_{}", plan.plan_id),
401        );
402
403        if let Some(val) = input {
404            ctx.set("input", val.clone());
405            // Also set per-root input
406            if let somatize_compiler::ExecutionPlan::Execute { node_id } = &plan.plan {
407                ctx.set(format!("__input_{node_id}"), val);
408            }
409        }
410
411        match execute(&plan.plan, &mut ctx, &self.filters, self.cache.as_ref()) {
412            Ok(()) => {
413                let output = ctx
414                    .execution_order
415                    .last()
416                    .and_then(|id| ctx.get(id))
417                    .cloned()
418                    .unwrap_or(Value::Empty);
419
420                PlanResult::Success {
421                    output,
422                    duration_ms: start.elapsed().as_millis() as u64,
423                    states: std::collections::HashMap::new(),
424                }
425            }
426            Err(e) => PlanResult::Failed {
427                error: e.to_string(),
428                duration_ms: start.elapsed().as_millis() as u64,
429            },
430        }
431    }
432
433    /// Fit mode: train each filter in topological order, return trained states.
434    fn execute_fit(
435        &mut self,
436        plan: &SerializedPlan,
437        input: Option<Value>,
438        y: Option<&Value>,
439        start: Instant,
440    ) -> PlanResult {
441        let run_id = format!("worker_fit_{}", plan.plan_id);
442        let x = input.unwrap_or(Value::Empty);
443
444        // Extract node execution order from plan
445        let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
446        let mut outputs: std::collections::HashMap<String, Value> =
447            std::collections::HashMap::new();
448        let mut trained_states: std::collections::HashMap<String, Value> =
449            std::collections::HashMap::new();
450
451        for node_id in &node_ids {
452            let filter = match self.filters.get(node_id) {
453                Some(f) => f,
454                None => {
455                    return PlanResult::Failed {
456                        error: format!("filter not found: {node_id}"),
457                        duration_ms: start.elapsed().as_millis() as u64,
458                    };
459                }
460            };
461
462            let meta = filter.meta();
463
464            self.event_bus.emit(Event::NodeStarted {
465                run_id: run_id.clone(),
466                node_id: node_id.to_string(),
467                kind: meta.kind,
468            });
469
470            let node_start = Instant::now();
471
472            // Resolve input: predecessor output or original input
473            let node_input = outputs
474                .values()
475                .last()
476                .cloned()
477                .unwrap_or_else(|| x.clone());
478
479            // Fit trainable filters, get/use state for forward
480            let state = if meta.kind == FilterKind::Trainable {
481                match filter.fit(&node_input, y) {
482                    Ok(s) => {
483                        self.filters.set_state(node_id, s.clone());
484                        trained_states.insert(node_id.clone(), s.clone());
485                        s
486                    }
487                    Err(e) => {
488                        return PlanResult::Failed {
489                            error: format!("fit({node_id}): {e}"),
490                            duration_ms: start.elapsed().as_millis() as u64,
491                        };
492                    }
493                }
494            } else {
495                self.filters
496                    .get_state(node_id)
497                    .cloned()
498                    .unwrap_or(Value::Empty)
499            };
500
501            // Forward with trained state
502            match filter.forward(&node_input, &state) {
503                Ok(output) => {
504                    self.event_bus.emit(Event::NodeCompleted {
505                        run_id: run_id.clone(),
506                        node_id: node_id.to_string(),
507                        duration: node_start.elapsed(),
508                        output_summary: format!("{output}"),
509                    });
510                    outputs.insert(node_id.clone(), output);
511                }
512                Err(e) => {
513                    return PlanResult::Failed {
514                        error: format!("forward({node_id}): {e}"),
515                        duration_ms: start.elapsed().as_millis() as u64,
516                    };
517                }
518            }
519        }
520
521        let output = outputs.values().last().cloned().unwrap_or(Value::Empty);
522
523        PlanResult::Success {
524            output,
525            duration_ms: start.elapsed().as_millis() as u64,
526            states: trained_states,
527        }
528    }
529
530    /// DataStore-backed streaming: read chunks via get_rows(), process with StreamExecutor.
531    /// Avoids loading the entire dataset into memory.
532    fn execute_streamed_from_store(
533        &mut self,
534        plan: &SerializedPlan,
535        store: &Arc<dyn DataStore>,
536        data_ref: &somatize_core::store::DataRef,
537        meta: &somatize_core::store::StoreMeta,
538        start: Instant,
539    ) -> PlanResult {
540        use somatize_runtime::stream::{FittedFilter, StreamExecutor};
541
542        let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
543        let fitted: Vec<FittedFilter> = node_ids
544            .iter()
545            .filter_map(|id| {
546                let filter = self.filters.get(id)?;
547                let state = self.filters.get_state(id).cloned().unwrap_or(Value::Empty);
548                Some(FittedFilter {
549                    name: id.clone(),
550                    filter,
551                    state,
552                })
553            })
554            .collect();
555
556        let mut executor = StreamExecutor::new(fitted);
557        let chunk_size = 1024;
558        let run_id = format!("worker_stream_{}", plan.plan_id);
559
560        self.event_bus.emit(Event::RunStarted {
561            run_id: run_id.clone(),
562            plan_summary: somatize_core::event::PlanSummary {
563                total_nodes: node_ids.len(),
564                cached_nodes: 0,
565                parallel_branches: 0,
566            },
567        });
568
569        let mut last_output = Value::Empty;
570        let total = meta.total_rows;
571        let mut chunk_idx = 0;
572
573        for row_start in (0..total).step_by(chunk_size) {
574            let len = chunk_size.min(total - row_start);
575            let chunk = match store.get_rows(data_ref, row_start, len) {
576                Ok(c) => c,
577                Err(e) => {
578                    return PlanResult::Failed {
579                        error: format!("get_rows({row_start}..{}): {e}", row_start + len),
580                        duration_ms: start.elapsed().as_millis() as u64,
581                    };
582                }
583            };
584
585            match executor.process_chunk(chunk) {
586                Ok(Some(output)) => last_output = output,
587                Ok(None) => {} // Barrier — accumulating
588                Err(e) => {
589                    return PlanResult::Failed {
590                        error: format!("stream chunk {chunk_idx}: {e}"),
591                        duration_ms: start.elapsed().as_millis() as u64,
592                    };
593                }
594            }
595            chunk_idx += 1;
596        }
597
598        // Flush barrier filters
599        match executor.flush() {
600            Ok(Some(output)) => last_output = output,
601            Ok(None) => {}
602            Err(e) => {
603                return PlanResult::Failed {
604                    error: format!("stream flush: {e}"),
605                    duration_ms: start.elapsed().as_millis() as u64,
606                };
607            }
608        }
609
610        tracing::info!(
611            "Streamed {chunk_idx} chunks ({total} rows) in {}ms",
612            start.elapsed().as_millis()
613        );
614
615        PlanResult::Success {
616            output: last_output,
617            duration_ms: start.elapsed().as_millis() as u64,
618            states: std::collections::HashMap::new(),
619        }
620    }
621
622    /// Check if this worker matches a remote target.
623    pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
624        match target {
625            somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
626            somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
627        }
628    }
629}
630
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_compiler::ExecutionPlan;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
    use somatize_core::value::Value;

    /// Stateless test filter: doubles every tensor element and passes any
    /// other value through unchanged.
    struct TestDoubler;

    impl Filter for TestDoubler {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"TestDoubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "TestDoubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    fn make_worker() -> Worker {
        Worker::new(
            "test_worker",
            Capabilities {
                cpu_cores: 4,
                ram_bytes: 8_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into(), "test".into()],
            },
        )
    }

    /// Single-node plan step.
    fn execute_node(node_id: &str) -> ExecutionPlan {
        ExecutionPlan::Execute {
            node_id: node_id.into(),
        }
    }

    /// Plan with an optional inline input and no shipped filters.
    fn inline_plan(
        plan_id: &str,
        plan: ExecutionPlan,
        input: Option<Value>,
        mode: ExecutionMode,
    ) -> SerializedPlan {
        SerializedPlan {
            plan_id: plan_id.into(),
            plan,
            input: input.map(|value| crate::protocol::InputSource::Inline { value }),
            filters: vec![],
            mode,
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn worker_registration() {
        let worker = make_worker();
        let msg = worker.registration_message();
        if let WorkerToCoordinator::Register {
            worker_id,
            capabilities,
        } = msg
        {
            assert_eq!(worker_id, "test_worker");
            assert_eq!(capabilities.cpu_cores, 4);
        } else {
            panic!("wrong message type");
        }
    }

    #[test]
    fn worker_executes_plan_successfully() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));

        let result = worker.execute_plan(&inline_plan(
            "p_001",
            execute_node("doubler"),
            Some(Value::tensor(vec![1.0, 2.0, 3.0], vec![3])),
            ExecutionMode::default(),
        ));

        match result {
            PlanResult::Success {
                output,
                duration_ms,
                ..
            } => {
                let (data, _) = output.as_tensor().unwrap();
                assert_eq!(data, &[2.0, 4.0, 6.0]);
                assert!(duration_ms < 1000);
            }
            other => panic!("expected success, got: {other:?}"),
        }
    }

    #[test]
    fn worker_handles_missing_filter() {
        let mut worker = make_worker();
        // No filters registered on purpose.

        let result = worker.execute_plan(&inline_plan(
            "p_002",
            execute_node("nonexistent"),
            None,
            ExecutionMode::default(),
        ));
        assert!(
            matches!(result, PlanResult::Failed { .. }),
            "expected failure, got: {result:?}"
        );
    }

    #[test]
    fn worker_matches_target_by_id() {
        let worker = make_worker();
        assert!(
            worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
                "test_worker".into()
            ))
        );
        assert!(
            !worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
                "other".into()
            ))
        );
    }

    #[test]
    fn worker_matches_target_by_tag() {
        let worker = make_worker();
        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("cpu".into())));
        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("test".into())));
        assert!(!worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("gpu".into())));
    }

    #[test]
    fn worker_executes_sequence() {
        let mut worker = make_worker();
        worker.register_filter("d1", Box::new(TestDoubler));
        worker.register_filter("d2", Box::new(TestDoubler));

        let result = worker.execute_plan(&inline_plan(
            "p_003",
            ExecutionPlan::Sequence(vec![execute_node("d1"), execute_node("d2")]),
            Some(Value::tensor(vec![5.0], vec![1])),
            ExecutionMode::default(),
        ));

        match result {
            PlanResult::Success { output, .. } => {
                let (data, _) = output.as_tensor().unwrap();
                assert_eq!(data, &[20.0]); // 5 * 2 * 2
            }
            other => panic!("expected success, got: {other:?}"),
        }
    }

    #[test]
    fn worker_emits_events() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let mut rx = worker.subscribe();

        worker.execute_plan(&inline_plan(
            "p_004",
            execute_node("doubler"),
            Some(Value::tensor(vec![1.0], vec![1])),
            ExecutionMode::default(),
        ));

        let events: Vec<Event> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        assert!(events.iter().any(|e| matches!(e, Event::NodeStarted { .. })));
        assert!(events.iter().any(|e| matches!(e, Event::NodeCompleted { .. })));
    }

    #[test]
    fn worker_fit_mode_forwards_stateless_filter() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));

        let result = worker.execute_plan(&inline_plan(
            "p_005",
            execute_node("doubler"),
            Some(Value::tensor(vec![1.0, 2.0], vec![2])),
            ExecutionMode::Fit { y: None },
        ));

        match result {
            PlanResult::Success { output, states, .. } => {
                let (data, _) = output.as_tensor().unwrap();
                assert_eq!(data, &[2.0, 4.0]);
                // Stateless filters are not fitted, so no states are reported.
                assert!(states.is_empty());
            }
            other => panic!("expected success, got: {other:?}"),
        }
    }

    #[test]
    fn missing_filter_state_is_empty() {
        let worker = make_worker();
        assert!(matches!(worker.get_filter_state("nope"), Value::Empty));
    }

    #[test]
    fn parse_missing_module_extracts_top_level_package() {
        let stderr = "Traceback (most recent call last):\n  File \"<string>\", line 2, in <module>\nModuleNotFoundError: No module named 'sklearn.linear_model'\n";
        assert_eq!(
            PickledFilterRunner::parse_missing_module(stderr),
            Some("sklearn".to_string())
        );
        assert_eq!(
            PickledFilterRunner::parse_missing_module("ValueError: bad input"),
            None
        );
    }
}