Skip to main content

somatize_worker/
worker.rs

1//! Worker — receives and executes plans from a coordinator.
2
3use crate::protocol::*;
4use somatize_core::cache::{CacheKey, CacheStore};
5use somatize_core::error::Result as SomaResult;
6use somatize_core::event::Event;
7use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
8use somatize_core::store::{DataStore, LocalDataStore};
9use somatize_core::value::Value;
10use somatize_runtime::{Context, EventBus, FilterLibrary, MemoryCache, execute};
11use std::sync::Arc;
12use std::time::Instant;
13
/// Filter backed by a cloudpickled Python object.
///
/// The pickled bytes are shipped to a Python subprocess on every call, where
/// they are unpickled and the requested method (`fit` / `forward`) is invoked.
pub(crate) struct PickledFilterRunner {
    /// Raw output of `cloudpickle.dumps()` for the user's filter object.
    pub(crate) pickled_bytes: Vec<u8>,
    /// Plan node identifier; used as the filter name and in error reports.
    pub(crate) node_id: String,
    /// Interpreter to launch (a venv interpreter or the system `python3`).
    pub(crate) python_path: String,
    /// Pip requirements installed before a single retry when an import fails.
    pub(crate) requirements: Vec<String>,
    /// `true` when the object has a meaningful `fit()` and should be trained.
    pub(crate) trainable: bool,
}
28
29impl Filter for PickledFilterRunner {
30    fn config_hash(&self) -> CacheKey {
31        CacheKey::from_parts(&[&self.pickled_bytes])
32    }
33
34    fn fit(&self, x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
35        self.run_python("fit", x)
36    }
37
38    fn forward(&self, x: &Value, state: &Value) -> SomaResult<Value> {
39        let input = if matches!(state, Value::Empty) {
40            x.clone()
41        } else {
42            Value::json(serde_json::json!({
43                "x": serde_json::to_value(x).unwrap_or_default(),
44                "state": serde_json::to_value(state).unwrap_or_default(),
45            }))
46        };
47        self.run_python("forward", &input)
48    }
49
50    fn meta(&self) -> FilterMeta {
51        FilterMeta {
52            name: self.node_id.clone(),
53            kind: if self.trainable {
54                FilterKind::Trainable
55            } else {
56                FilterKind::Stateless
57            },
58            cacheable: true,
59            differentiable: false,
60            stream_mode: StreamMode::FixedState,
61            distribution: somatize_core::filter::Distribution::Local,
62            input_schema: None,
63            output_schema: None,
64        }
65    }
66}
67
68impl PickledFilterRunner {
69    fn run_python(&self, method: &str, input: &Value) -> SomaResult<Value> {
70        self.run_python_with_retry(method, input, true)
71    }
72
73    fn run_python_with_retry(
74        &self,
75        method: &str,
76        input: &Value,
77        allow_retry: bool,
78    ) -> SomaResult<Value> {
79        use base64::engine::{Engine, general_purpose::STANDARD};
80        use std::io::Write;
81
82        let input_json = serde_json::to_string(input)
83            .map_err(|e| somatize_core::error::SomaError::Other(format!("serialize input: {e}")))?;
84        let pickled_b64 = STANDARD.encode(&self.pickled_bytes);
85
86        let script = format!(
87            r#"
88import json, sys, base64, cloudpickle
89
90def unwrap_value(v):
91    """Convert Soma Value JSON to native Python types."""
92    if isinstance(v, dict) and "type" in v and "data" in v:
93        t = v["type"]
94        d = v["data"]
95        if t == "Tensor":
96            return d.get("values", [])
97        if t == "Json":
98            return d
99        if t == "Empty":
100            return {{}}
101        if t == "Bytes":
102            return bytes(d)
103    return v
104
105pickled_b64 = sys.stdin.readline().strip()
106input_line = sys.stdin.read()
107
108pickled = base64.b64decode(pickled_b64)
109obj = cloudpickle.loads(pickled)
110raw = json.loads(input_line)
111input_data = unwrap_value(raw)
112
113if isinstance(input_data, dict) and "x" in input_data and "state" in input_data:
114    x = unwrap_value(input_data["x"])
115    state = unwrap_value(input_data["state"])
116    result = obj.{method}(x, state)
117else:
118    result = obj.{method}(input_data, {{}})
119
120print(json.dumps(result))
121"#,
122        );
123
124        let mut child = std::process::Command::new(&self.python_path)
125            .args(["-c", &script])
126            .stdin(std::process::Stdio::piped())
127            .stdout(std::process::Stdio::piped())
128            .stderr(std::process::Stdio::piped())
129            .spawn()
130            .map_err(|e| {
131                somatize_core::error::SomaError::Other(format!("python spawn failed: {e}"))
132            })?;
133
134        if let Some(mut stdin) = child.stdin.take() {
135            let _ = writeln!(stdin, "{pickled_b64}");
136            let _ = write!(stdin, "{input_json}");
137        }
138
139        let output = child.wait_with_output().map_err(|e| {
140            somatize_core::error::SomaError::Other(format!("python exec failed: {e}"))
141        })?;
142
143        if !output.status.success() {
144            let stderr = String::from_utf8_lossy(&output.stderr);
145
146            // Retry on ModuleNotFoundError: install known requirements + missing package
147            if allow_retry && stderr.contains("ModuleNotFoundError") {
148                let missing = Self::parse_missing_module(&stderr);
149                // Collect packages to install: known requirements + the missing one
150                let mut to_install: Vec<String> = self.requirements.clone();
151                if let Some(ref m) = missing
152                    && !to_install.iter().any(|r| r == m)
153                {
154                    to_install.push(m.clone());
155                }
156                if !to_install.is_empty() {
157                    let names = to_install.join(", ");
158                    tracing::warn!(
159                        "Missing module for filter '{}', installing: {names}",
160                        self.node_id
161                    );
162                    let mut args = vec!["-m", "pip", "install", "--quiet"];
163                    let refs: Vec<&str> = to_install.iter().map(|s| s.as_str()).collect();
164                    args.extend(refs);
165                    let install = std::process::Command::new(&self.python_path)
166                        .args(&args)
167                        .output();
168                    if let Ok(res) = install
169                        && res.status.success()
170                    {
171                        tracing::info!("Installed [{names}], retrying...");
172                        return self.run_python_with_retry(method, input, false);
173                    }
174                }
175            }
176
177            return Err(somatize_core::error::SomaError::Execution {
178                node_id: self.node_id.clone(),
179                message: format!("Python error: {stderr}"),
180            });
181        }
182
183        let stdout = String::from_utf8_lossy(&output.stdout);
184        let result: serde_json::Value = serde_json::from_str(stdout.trim()).map_err(|e| {
185            somatize_core::error::SomaError::Other(format!(
186                "parse python output: {e}\nstdout: {stdout}"
187            ))
188        })?;
189
190        if let Some(arr) = result.as_array() {
191            let values: Vec<f64> = arr.iter().filter_map(|v| v.as_f64()).collect();
192            if !values.is_empty() {
193                return Ok(Value::tensor(values.clone(), vec![values.len()]));
194            }
195        }
196
197        Ok(Value::json(result))
198    }
199
200    /// Parse "ModuleNotFoundError: No module named 'xxx'" from stderr.
201    fn parse_missing_module(stderr: &str) -> Option<String> {
202        for line in stderr.lines().rev() {
203            if line.contains("ModuleNotFoundError") {
204                // "ModuleNotFoundError: No module named 'xxx'"
205                if let Some(start) = line.find('\'') {
206                    let rest = &line[start + 1..];
207                    if let Some(end) = rest.find('\'') {
208                        return Some(rest[..end].split('.').next()?.to_string());
209                    }
210                }
211            }
212        }
213        None
214    }
215}
216
217/// Worker state: manages execution of plans received from a coordinator.
218pub struct Worker {
219    pub id: WorkerId,
220    pub capabilities: Capabilities,
221    event_bus: Arc<EventBus>,
222    cache: Arc<dyn CacheStore>,
223    filters: FilterLibrary,
224    /// Optional persistent DataStore (S3, Zarr, etc.) — configured by user.
225    data_store: Option<Arc<dyn DataStore>>,
226    /// Temporary local store for HTTP bulk uploads — auto-created, auto-cleaned.
227    temp_store: Arc<LocalDataStore>,
228    /// Environment manager for creating venvs with filter dependencies.
229    env_manager: crate::env_manager::EnvManager,
230}
231
232impl Worker {
233    pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
234        let worker_id: String = id.into();
235        let temp_path = std::env::temp_dir().join(format!("soma-uploads-{worker_id}"));
236        let temp_store = LocalDataStore::new(temp_path);
237        let env_path = std::env::temp_dir().join(format!("soma-envs-{worker_id}"));
238        Self {
239            id: worker_id,
240            capabilities,
241            event_bus: Arc::new(EventBus::new(256)),
242            cache: Arc::new(MemoryCache::default()),
243            filters: FilterLibrary::new(),
244            data_store: None,
245            temp_store: Arc::new(temp_store),
246            env_manager: crate::env_manager::EnvManager::new(
247                env_path,
248                crate::env_manager::EnvType::Venv,
249            ),
250        }
251    }
252
253    /// Set a custom cache store (e.g. tiered or shared).
254    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
255        self.cache = cache;
256        self
257    }
258
259    /// Set a persistent DataStore (S3, Zarr, etc.) for large data references.
260    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
261        self.data_store = Some(store);
262        self
263    }
264
265    /// Set a custom temp directory for HTTP bulk uploads.
266    pub fn with_temp_dir(mut self, path: std::path::PathBuf) -> Self {
267        self.temp_store = Arc::new(LocalDataStore::new(path));
268        self
269    }
270
271    /// Get the temp store (for HTTP upload endpoint).
272    pub fn temp_store(&self) -> &Arc<LocalDataStore> {
273        &self.temp_store
274    }
275
276    /// Register a filter that this worker can execute.
277    pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
278        self.filters.register(node_id, filter);
279    }
280
281    /// Get a filter by node_id (for stream executor construction).
282    pub fn get_filter(&self, node_id: &str) -> Option<Arc<dyn Filter>> {
283        self.filters.get(node_id)
284    }
285
286    /// Get trained state for a filter.
287    pub fn get_filter_state(&self, node_id: &str) -> Value {
288        self.filters
289            .get_state(node_id)
290            .cloned()
291            .unwrap_or(Value::Empty)
292    }
293
294    /// Set trained state for a filter.
295    pub fn set_filter_state(&mut self, node_id: &str, state: Value) {
296        self.filters.set_state(node_id, state);
297    }
298
299    /// Wrap output in the right delivery: inline for small, DataRef for large.
300    pub fn wrap_output(&self, output: Value) -> OutputDelivery {
301        let size = serde_json::to_vec(&output).map(|v| v.len()).unwrap_or(0);
302        if size >= somatize_core::store::INLINE_THRESHOLD_BYTES {
303            let key = somatize_core::cache::CacheKey::hash_data(
304                &serde_json::to_vec(&output).unwrap_or_default(),
305            );
306            if let Ok(data_ref) = self.temp_store.put(&key, &output) {
307                return OutputDelivery::Reference { data_ref };
308            }
309        }
310        OutputDelivery::Inline { value: output }
311    }
312
313    /// Subscribe to execution events.
314    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
315        self.event_bus.subscribe()
316    }
317
318    /// Build a registration message.
319    pub fn registration_message(&self) -> WorkerToCoordinator {
320        WorkerToCoordinator::Register {
321            worker_id: self.id.clone(),
322            capabilities: self.capabilities.clone(),
323        }
324    }
325
326    /// Execute a serialized plan.
327    ///
328    /// If the plan contains serialized filter definitions, they are registered
329    /// temporarily for this execution (alongside any pre-registered filters).
330    ///
331    /// In **Fit** mode: fits each filter (topological order), stores trained states,
332    /// then forwards to propagate outputs. Returns states so the client can cache them.
333    ///
334    /// In **Forward** mode: executes the compiled plan directly.
335    pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
336        let start = Instant::now();
337
338        // Collect all requirements from serialized filters
339        let all_reqs: Vec<String> = plan
340            .filters
341            .iter()
342            .flat_map(|sf| sf.requirements.iter().cloned())
343            .collect::<std::collections::HashSet<_>>()
344            .into_iter()
345            .collect();
346
347        // Create/reuse venv if there are pip requirements, otherwise use system python
348        let python_path = if all_reqs.is_empty() {
349            "python3".to_string()
350        } else {
351            let reqs_str = all_reqs.join("\n");
352            match self.env_manager.ensure_env(&plan.plan_id, &reqs_str) {
353                Ok(path) => {
354                    tracing::info!("Using venv for plan {}: {:?}", plan.plan_id, path);
355                    path.to_string_lossy().to_string()
356                }
357                Err(e) => {
358                    tracing::warn!("Failed to create venv, falling back to system python: {e}");
359                    "python3".to_string()
360                }
361            }
362        };
363
364        // Resolve venv site-packages path for EmbeddedPyFilter
365        #[cfg(feature = "embedded-python")]
366        let site_packages = if python_path != "python3" {
367            // Venv python path is like .../bin/python → site-packages is ../lib/pythonX.Y/site-packages
368            let venv_dir = std::path::Path::new(&python_path)
369                .parent()
370                .and_then(|bin| bin.parent());
371            venv_dir.and_then(|d| {
372                std::fs::read_dir(d.join("lib"))
373                    .ok()?
374                    .filter_map(|e| e.ok())
375                    .find(|e| e.file_name().to_string_lossy().starts_with("python"))
376                    .map(|e| e.path().join("site-packages").to_string_lossy().to_string())
377            })
378        } else {
379            None
380        };
381
382        // Register filters: prefer EmbeddedPyFilter (PyO3) when available
383        for sf in &plan.filters {
384            let filter: Box<dyn Filter> = {
385                #[cfg(feature = "embedded-python")]
386                {
387                    match crate::py_filter::EmbeddedPyFilter::new(
388                        &sf.pickled_filter,
389                        sf.node_id.clone(),
390                        sf.trainable,
391                        site_packages.as_deref(),
392                    ) {
393                        Ok(embedded) => {
394                            tracing::info!("Using embedded PyO3 filter for '{}'", sf.node_id);
395                            Box::new(embedded)
396                        }
397                        Err(e) => {
398                            tracing::warn!(
399                                "PyO3 embed failed for '{}': {e}, falling back to subprocess",
400                                sf.node_id
401                            );
402                            Box::new(PickledFilterRunner {
403                                pickled_bytes: sf.pickled_filter.clone(),
404                                node_id: sf.node_id.clone(),
405                                python_path: python_path.clone(),
406                                requirements: sf.requirements.clone(),
407                                trainable: sf.trainable,
408                            })
409                        }
410                    }
411                }
412                #[cfg(not(feature = "embedded-python"))]
413                {
414                    Box::new(PickledFilterRunner {
415                        pickled_bytes: sf.pickled_filter.clone(),
416                        node_id: sf.node_id.clone(),
417                        python_path: python_path.clone(),
418                        requirements: sf.requirements.clone(),
419                        trainable: sf.trainable,
420                    })
421                }
422            };
423            self.filters.register(&sf.node_id, filter);
424            if let Some(state) = &sf.state {
425                self.filters.set_state(&sf.node_id, state.clone());
426            }
427        }
428
429        // Resolve input: inline, DataStore, or temp store (HTTP upload)
430        let input_value = plan.input.as_ref().map(|src| match src {
431            InputSource::Inline { value } => value.clone(),
432            InputSource::Reference { data_ref } => {
433                // Try persistent DataStore first, then temp store (HTTP uploads)
434                if let Some(store) = &self.data_store
435                    && let Ok(val) = store.get(data_ref)
436                {
437                    return val;
438                }
439                self.temp_store.get(data_ref).unwrap_or_else(|e| {
440                    tracing::warn!("Failed to resolve DataRef: {e}");
441                    Value::Empty
442                })
443            }
444        });
445
446        // DataStore-backed streaming: if input is a large DataRef and we have a store,
447        // read chunks via get_rows() and process with StreamExecutor (no full materialization).
448        if let Some(InputSource::Reference { data_ref }) = &plan.input
449            && let Some(store) = self.data_store.clone()
450            && let Ok(meta) = store.meta(data_ref)
451            && meta.total_rows > 1024
452        {
453            return self.execute_streamed_from_store(plan, &store, data_ref, &meta, start);
454        }
455
456        match &plan.mode {
457            ExecutionMode::Fit { y } => self.execute_fit(plan, input_value, y.as_ref(), start),
458            ExecutionMode::Forward => self.execute_forward(plan, input_value, start),
459        }
460    }
461
462    /// Forward mode: run the compiled execution plan.
463    fn execute_forward(
464        &mut self,
465        plan: &SerializedPlan,
466        input: Option<Value>,
467        start: Instant,
468    ) -> PlanResult {
469        let mut ctx = Context::new(
470            self.event_bus.clone(),
471            format!("worker_run_{}", plan.plan_id),
472        );
473
474        if let Some(val) = input {
475            ctx.set("input", val.clone());
476            // Also set per-root input
477            if let somatize_compiler::ExecutionPlan::Execute { node_id } = &plan.plan {
478                ctx.set(format!("__input_{node_id}"), val);
479            }
480        }
481
482        match execute(&plan.plan, &mut ctx, &self.filters, self.cache.as_ref()) {
483            Ok(()) => {
484                let output = ctx
485                    .execution_order
486                    .last()
487                    .and_then(|id| ctx.get(id))
488                    .cloned()
489                    .unwrap_or(Value::Empty);
490
491                PlanResult::Success {
492                    output: self.wrap_output(output),
493                    duration_ms: start.elapsed().as_millis() as u64,
494                    states: std::collections::HashMap::new(),
495                }
496            }
497            Err(e) => PlanResult::Failed {
498                error: e.to_string(),
499                duration_ms: start.elapsed().as_millis() as u64,
500            },
501        }
502    }
503
504    /// Fit mode: train each filter in topological order, return trained states.
505    fn execute_fit(
506        &mut self,
507        plan: &SerializedPlan,
508        input: Option<Value>,
509        y: Option<&Value>,
510        start: Instant,
511    ) -> PlanResult {
512        let run_id = format!("worker_fit_{}", plan.plan_id);
513        let x = input.unwrap_or(Value::Empty);
514
515        // Extract node execution order from plan
516        let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
517        let mut outputs: std::collections::HashMap<String, Value> =
518            std::collections::HashMap::new();
519        let mut trained_states: std::collections::HashMap<String, Value> =
520            std::collections::HashMap::new();
521
522        for node_id in &node_ids {
523            let filter = match self.filters.get(node_id) {
524                Some(f) => f,
525                None => {
526                    return PlanResult::Failed {
527                        error: format!("filter not found: {node_id}"),
528                        duration_ms: start.elapsed().as_millis() as u64,
529                    };
530                }
531            };
532
533            let meta = filter.meta();
534
535            self.event_bus.emit(Event::NodeStarted {
536                run_id: run_id.clone(),
537                node_id: node_id.to_string(),
538                kind: meta.kind,
539            });
540
541            let node_start = Instant::now();
542
543            // Resolve input: predecessor output or original input
544            let node_input = outputs
545                .values()
546                .last()
547                .cloned()
548                .unwrap_or_else(|| x.clone());
549
550            // Fit trainable filters, get/use state for forward
551            let state = if meta.kind == FilterKind::Trainable {
552                match filter.fit(&node_input, y) {
553                    Ok(s) => {
554                        self.filters.set_state(node_id, s.clone());
555                        trained_states.insert(node_id.clone(), s.clone());
556                        s
557                    }
558                    Err(e) => {
559                        return PlanResult::Failed {
560                            error: format!("fit({node_id}): {e}"),
561                            duration_ms: start.elapsed().as_millis() as u64,
562                        };
563                    }
564                }
565            } else {
566                self.filters
567                    .get_state(node_id)
568                    .cloned()
569                    .unwrap_or(Value::Empty)
570            };
571
572            // Forward with trained state
573            match filter.forward(&node_input, &state) {
574                Ok(output) => {
575                    self.event_bus.emit(Event::NodeCompleted {
576                        run_id: run_id.clone(),
577                        node_id: node_id.to_string(),
578                        duration: node_start.elapsed(),
579                        output_summary: format!("{output}"),
580                    });
581                    outputs.insert(node_id.clone(), output);
582                }
583                Err(e) => {
584                    return PlanResult::Failed {
585                        error: format!("forward({node_id}): {e}"),
586                        duration_ms: start.elapsed().as_millis() as u64,
587                    };
588                }
589            }
590        }
591
592        let output = outputs.values().last().cloned().unwrap_or(Value::Empty);
593
594        PlanResult::Success {
595            output: self.wrap_output(output),
596            duration_ms: start.elapsed().as_millis() as u64,
597            states: trained_states,
598        }
599    }
600
601    /// DataStore-backed streaming: read chunks via get_rows(), process with StreamExecutor.
602    /// Avoids loading the entire dataset into memory.
603    fn execute_streamed_from_store(
604        &mut self,
605        plan: &SerializedPlan,
606        store: &Arc<dyn DataStore>,
607        data_ref: &somatize_core::store::DataRef,
608        meta: &somatize_core::store::StoreMeta,
609        start: Instant,
610    ) -> PlanResult {
611        use somatize_runtime::stream::{FittedFilter, StreamExecutor};
612
613        let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
614        let fitted: Vec<FittedFilter> = node_ids
615            .iter()
616            .filter_map(|id| {
617                let filter = self.filters.get(id)?;
618                let state = self.filters.get_state(id).cloned().unwrap_or(Value::Empty);
619                Some(FittedFilter {
620                    name: id.clone(),
621                    filter,
622                    state,
623                })
624            })
625            .collect();
626
627        let mut executor = StreamExecutor::new(fitted);
628        let chunk_size = 1024;
629        let run_id = format!("worker_stream_{}", plan.plan_id);
630
631        self.event_bus.emit(Event::RunStarted {
632            run_id: run_id.clone(),
633            plan_summary: somatize_core::event::PlanSummary {
634                total_nodes: node_ids.len(),
635                cached_nodes: 0,
636                parallel_branches: 0,
637            },
638        });
639
640        let mut last_output = Value::Empty;
641        let total = meta.total_rows;
642        let mut chunk_idx = 0;
643
644        for row_start in (0..total).step_by(chunk_size) {
645            let len = chunk_size.min(total - row_start);
646            let chunk = match store.get_rows(data_ref, row_start, len) {
647                Ok(c) => c,
648                Err(e) => {
649                    return PlanResult::Failed {
650                        error: format!("get_rows({row_start}..{}): {e}", row_start + len),
651                        duration_ms: start.elapsed().as_millis() as u64,
652                    };
653                }
654            };
655
656            match executor.process_chunk(chunk) {
657                Ok(Some(output)) => last_output = output,
658                Ok(None) => {} // Barrier — accumulating
659                Err(e) => {
660                    return PlanResult::Failed {
661                        error: format!("stream chunk {chunk_idx}: {e}"),
662                        duration_ms: start.elapsed().as_millis() as u64,
663                    };
664                }
665            }
666            chunk_idx += 1;
667        }
668
669        // Flush barrier filters
670        match executor.flush() {
671            Ok(Some(output)) => last_output = output,
672            Ok(None) => {}
673            Err(e) => {
674                return PlanResult::Failed {
675                    error: format!("stream flush: {e}"),
676                    duration_ms: start.elapsed().as_millis() as u64,
677                };
678            }
679        }
680
681        tracing::info!(
682            "Streamed {chunk_idx} chunks ({total} rows) in {}ms",
683            start.elapsed().as_millis()
684        );
685
686        PlanResult::Success {
687            output: self.wrap_output(last_output),
688            duration_ms: start.elapsed().as_millis() as u64,
689            states: std::collections::HashMap::new(),
690        }
691    }
692
693    /// Check if this worker matches a remote target.
694    pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
695        match target {
696            somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
697            somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
698        }
699    }
700}
701
702#[cfg(test)]
703mod tests {
704    use super::*;
705    use somatize_compiler::ExecutionPlan;
706    use somatize_core::cache::CacheKey;
707    use somatize_core::error::Result as SomaResult;
708    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
709    use somatize_core::value::Value;
710
711    struct TestDoubler;
712
713    impl Filter for TestDoubler {
714        fn config_hash(&self) -> CacheKey {
715            CacheKey::from_parts(&[b"TestDoubler"])
716        }
717        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
718            Ok(Value::Empty)
719        }
720        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
721            match x {
722                Value::Tensor { values, shape } => {
723                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
724                    Ok(Value::tensor(doubled, shape.clone()))
725                }
726                _ => Ok(x.clone()),
727            }
728        }
729        fn meta(&self) -> FilterMeta {
730            FilterMeta {
731                name: "TestDoubler".into(),
732                kind: FilterKind::Stateless,
733                cacheable: true,
734                differentiable: true,
735                stream_mode: StreamMode::FixedState,
736                distribution: somatize_core::filter::Distribution::Local,
737                input_schema: None,
738                output_schema: None,
739            }
740        }
741    }
742
743    fn make_worker() -> Worker {
744        Worker::new(
745            "test_worker",
746            Capabilities {
747                cpu_cores: 4,
748                ram_bytes: 8_000_000_000,
749                gpus: vec![],
750                python_envs: vec![],
751                tags: vec!["cpu".into(), "test".into()],
752            },
753        )
754    }
755
756    #[test]
757    fn worker_registration() {
758        let worker = make_worker();
759        let msg = worker.registration_message();
760        if let WorkerToCoordinator::Register {
761            worker_id,
762            capabilities,
763        } = msg
764        {
765            assert_eq!(worker_id, "test_worker");
766            assert_eq!(capabilities.cpu_cores, 4);
767        } else {
768            panic!("wrong message type");
769        }
770    }
771
772    #[test]
773    fn worker_executes_plan_successfully() {
774        let mut worker = make_worker();
775        worker.register_filter("doubler", Box::new(TestDoubler));
776
777        let plan = SerializedPlan {
778            plan_id: "p_001".into(),
779            plan: ExecutionPlan::Execute {
780                node_id: "doubler".into(),
781            },
782            input: Some(crate::protocol::InputSource::Inline {
783                value: Value::tensor(vec![1.0, 2.0, 3.0], vec![3]),
784            }),
785            filters: vec![],
786            mode: ExecutionMode::default(),
787            metadata: serde_json::json!({}),
788        };
789
790        let result = worker.execute_plan(&plan);
791
792        if let PlanResult::Success {
793            output,
794            duration_ms,
795            ..
796        } = result
797        {
798            let value = match output {
799                OutputDelivery::Inline { value } => value,
800                _ => panic!("expected inline output"),
801            };
802            let (data, _) = value.as_tensor().unwrap();
803            assert_eq!(data, &[2.0, 4.0, 6.0]);
804            assert!(duration_ms < 1000);
805        } else {
806            panic!("expected success, got: {result:?}");
807        }
808    }
809
810    #[test]
811    fn worker_handles_missing_filter() {
812        let mut worker = make_worker();
813        // Don't register any filters
814
815        let plan = SerializedPlan {
816            plan_id: "p_002".into(),
817            plan: ExecutionPlan::Execute {
818                node_id: "nonexistent".into(),
819            },
820            input: None,
821            filters: vec![],
822            mode: ExecutionMode::default(),
823            metadata: serde_json::json!({}),
824        };
825
826        let result = worker.execute_plan(&plan);
827        assert!(matches!(result, PlanResult::Failed { .. }));
828    }
829
830    #[test]
831    fn worker_matches_target_by_id() {
832        let worker = make_worker();
833        assert!(
834            worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
835                "test_worker".into()
836            ))
837        );
838        assert!(
839            !worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
840                "other".into()
841            ))
842        );
843    }
844
845    #[test]
846    fn worker_matches_target_by_tag() {
847        let worker = make_worker();
848        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("cpu".into())));
849        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("test".into())));
850        assert!(!worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("gpu".into())));
851    }
852
853    #[test]
854    fn worker_executes_sequence() {
855        let mut worker = make_worker();
856        worker.register_filter("d1", Box::new(TestDoubler));
857        worker.register_filter("d2", Box::new(TestDoubler));
858
859        let plan = SerializedPlan {
860            plan_id: "p_003".into(),
861            plan: ExecutionPlan::Sequence(vec![
862                ExecutionPlan::Execute {
863                    node_id: "d1".into(),
864                },
865                ExecutionPlan::Execute {
866                    node_id: "d2".into(),
867                },
868            ]),
869            input: Some(crate::protocol::InputSource::Inline {
870                value: Value::tensor(vec![5.0], vec![1]),
871            }),
872            filters: vec![],
873            mode: ExecutionMode::default(),
874            metadata: serde_json::json!({}),
875        };
876
877        let result = worker.execute_plan(&plan);
878        if let PlanResult::Success { output, .. } = result {
879            let value = match output {
880                OutputDelivery::Inline { value } => value,
881                _ => panic!("expected inline output"),
882            };
883            let (data, _) = value.as_tensor().unwrap();
884            assert_eq!(data, &[20.0]); // 5 * 2 * 2
885        } else {
886            panic!("expected success");
887        }
888    }
889
890    #[test]
891    fn worker_emits_events() {
892        let mut worker = make_worker();
893        worker.register_filter("doubler", Box::new(TestDoubler));
894        let mut rx = worker.subscribe();
895
896        let plan = SerializedPlan {
897            plan_id: "p_004".into(),
898            plan: ExecutionPlan::Execute {
899                node_id: "doubler".into(),
900            },
901            input: Some(crate::protocol::InputSource::Inline {
902                value: Value::tensor(vec![1.0], vec![1]),
903            }),
904            filters: vec![],
905            mode: ExecutionMode::default(),
906            metadata: serde_json::json!({}),
907        };
908
909        worker.execute_plan(&plan);
910
911        let mut events = Vec::new();
912        while let Ok(e) = rx.try_recv() {
913            events.push(e);
914        }
915        assert!(
916            events
917                .iter()
918                .any(|e| matches!(e, Event::NodeStarted { .. }))
919        );
920        assert!(
921            events
922                .iter()
923                .any(|e| matches!(e, Event::NodeCompleted { .. }))
924        );
925    }
926}