Skip to main content

somatize_worker/
worker.rs

1//! Worker — receives and executes plans from a coordinator.
2
3use crate::protocol::*;
4use somatize_core::cache::CacheStore;
5use somatize_core::event::Event;
6use somatize_core::filter::Filter;
7use somatize_core::store::{DataStore, LocalDataStore};
8use somatize_core::value::Value;
9use somatize_runtime::{EventBus, FilterLibrary, MemoryCache, Runner};
10use std::sync::Arc;
11use std::time::Instant;
12
/// Worker state: manages execution of plans received from a coordinator.
///
/// A worker owns the filters it can execute (and their trained states), the
/// cache and event bus used while running them, and the storage / environment
/// helpers needed to receive inputs and spawn Python subprocesses for
/// serialized filters.
pub struct Worker {
    /// Identifier sent to the coordinator in the registration message and
    /// matched by `RemoteTarget::WorkerId`.
    pub id: WorkerId,
    /// Resources advertised to the coordinator; `tags` are matched by
    /// `RemoteTarget::Tag`.
    pub capabilities: Capabilities,
    /// Event bus handed to the runner; `subscribe()` returns a broadcast receiver on it.
    event_bus: Arc<EventBus>,
    /// Cache handed to the runner; in-memory by default, replaceable via `with_cache`.
    cache: Arc<dyn CacheStore>,
    /// Filters this worker can execute, keyed by node id, together with their trained states.
    filters: FilterLibrary,
    /// Optional persistent DataStore (S3, Zarr, etc.) — configured by user.
    data_store: Option<Arc<dyn DataStore>>,
    /// Temporary local store for HTTP bulk uploads — auto-created, auto-cleaned.
    temp_store: Arc<LocalDataStore>,
    /// Environment manager for creating venvs with filter dependencies.
    env_manager: crate::env_manager::EnvManager,
}
27
28impl Worker {
29    pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
30        let worker_id: String = id.into();
31        let temp_path = std::env::temp_dir().join(format!("soma-uploads-{worker_id}"));
32        let temp_store = LocalDataStore::new(temp_path);
33        let env_path = std::env::temp_dir().join(format!("soma-envs-{worker_id}"));
34        Self {
35            id: worker_id,
36            capabilities,
37            event_bus: Arc::new(EventBus::new(256)),
38            cache: Arc::new(MemoryCache::default()),
39            filters: FilterLibrary::new(),
40            data_store: None,
41            temp_store: Arc::new(temp_store),
42            env_manager: crate::env_manager::EnvManager::new(
43                env_path,
44                crate::env_manager::EnvType::Venv,
45            ),
46        }
47    }
48
49    /// Set a custom cache store (e.g. tiered or shared).
50    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
51        self.cache = cache;
52        self
53    }
54
55    /// Set a persistent DataStore (S3, Zarr, etc.) for large data references.
56    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
57        self.data_store = Some(store);
58        self
59    }
60
61    /// Set a custom temp directory for HTTP bulk uploads.
62    pub fn with_temp_dir(mut self, path: std::path::PathBuf) -> Self {
63        self.temp_store = Arc::new(LocalDataStore::new(path));
64        self
65    }
66
67    /// Get the temp store (for HTTP upload endpoint).
68    pub fn temp_store(&self) -> &Arc<LocalDataStore> {
69        &self.temp_store
70    }
71
72    /// Register a filter that this worker can execute.
73    pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
74        self.filters.register(node_id, filter);
75    }
76
77    /// Get a filter by node_id (for stream executor construction).
78    pub fn get_filter(&self, node_id: &str) -> Option<Arc<dyn Filter>> {
79        self.filters.get(node_id)
80    }
81
82    /// Get trained state for a filter.
83    pub fn get_filter_state(&self, node_id: &str) -> Arc<Value> {
84        self.filters
85            .get_state(node_id)
86            .unwrap_or_else(|| Arc::new(Value::Empty))
87    }
88
89    /// Set trained state for a filter.
90    pub fn set_filter_state(&mut self, node_id: &str, state: Value) {
91        self.filters.set_state(node_id, state);
92    }
93
94    /// Wrap output in the right delivery: inline for small, DataRef for large.
95    pub fn wrap_output(&self, output: Value) -> OutputDelivery {
96        let size = serde_json::to_vec(&output).map(|v| v.len()).unwrap_or(0);
97        if size >= somatize_core::store::INLINE_THRESHOLD_BYTES {
98            let key = somatize_core::cache::CacheKey::hash_data(
99                &serde_json::to_vec(&output).unwrap_or_default(),
100            );
101            if let Ok(data_ref) = self.temp_store.put(&key, &output) {
102                return OutputDelivery::Reference { data_ref };
103            }
104        }
105        OutputDelivery::Inline { value: output }
106    }
107
108    /// Subscribe to execution events.
109    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
110        self.event_bus.subscribe()
111    }
112
113    /// Build a registration message.
114    pub fn registration_message(&self) -> WorkerToCoordinator {
115        WorkerToCoordinator::Register {
116            worker_id: self.id.clone(),
117            capabilities: self.capabilities.clone(),
118        }
119    }
120
121    /// Execute a serialized plan.
122    ///
123    /// If the plan contains serialized filter definitions, they are registered
124    /// temporarily for this execution (alongside any pre-registered filters).
125    ///
126    /// In **Fit** mode: fits each filter (topological order), stores trained states,
127    /// then forwards to propagate outputs. Returns states so the client can cache them.
128    ///
129    /// In **Forward** mode: executes the compiled plan directly.
130    pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
131        let start = Instant::now();
132        let _span = tracing::info_span!(
133            "execute_plan",
134            plan_id = %plan.plan_id,
135            n_filters = plan.filters.len(),
136            mode = ?plan.mode,
137        )
138        .entered();
139
140        tracing::info!(
141            "Plan received: {} filters, mode={:?}",
142            plan.filters.len(),
143            plan.mode
144        );
145
146        // Collect all requirements from serialized filters
147        let all_reqs: Vec<String> = plan
148            .filters
149            .iter()
150            .flat_map(|sf| sf.requirements.iter().cloned())
151            .collect::<std::collections::HashSet<_>>()
152            .into_iter()
153            .collect();
154
155        // Create/reuse venv if there are pip requirements, otherwise use system python
156        let python_path = if all_reqs.is_empty() {
157            "python3".to_string()
158        } else {
159            let reqs_str = all_reqs.join("\n");
160            match self.env_manager.ensure_env(&plan.plan_id, &reqs_str) {
161                Ok(path) => {
162                    tracing::info!("Using venv for plan {}: {:?}", plan.plan_id, path);
163                    path.to_string_lossy().to_string()
164                }
165                Err(e) => {
166                    tracing::warn!("Failed to create venv, falling back to system python: {e}");
167                    "python3".to_string()
168                }
169            }
170        };
171
172        // No site-packages resolution needed — subprocess uses the venv python directly
173
174        // Spawn ONE Python subprocess for all filters in this plan.
175        // All filters share the same process (needed for Composite autograd).
176        let filter_specs: Vec<(String, Vec<u8>, bool)> = plan
177            .filters
178            .iter()
179            .map(|sf| (sf.node_id.clone(), sf.pickled_filter.clone(), sf.trainable))
180            .collect();
181
182        if !filter_specs.is_empty() {
183            let filter_names: Vec<&str> =
184                plan.filters.iter().map(|sf| sf.node_id.as_str()).collect();
185            tracing::info!(
186                python = %python_path,
187                filters = ?filter_names,
188                "Spawning Python process for {} filters",
189                filter_specs.len()
190            );
191
192            let mut proc = crate::python_process::PythonProcess::spawn(&python_path, &filter_specs)
193                .map_err(|e| {
194                    tracing::error!("Failed to spawn Python process: {e}");
195                    e
196                })
197                .expect("PythonProcess spawn failed");
198
199            // Load trained states from previous epochs (SET_STATE)
200            for sf in &plan.filters {
201                if let Some(state) = &sf.state {
202                    let size = match state {
203                        Value::Bytes(b) => b.len(),
204                        _ => 0,
205                    };
206                    tracing::info!(
207                        node_id = %sf.node_id,
208                        size_bytes = size,
209                        "Loading trained state from previous epoch"
210                    );
211                    if let Err(e) = proc.set_state(&sf.node_id, state) {
212                        tracing::warn!(
213                            node_id = %sf.node_id,
214                            error = %e,
215                            "Failed to load state (will use fresh weights)"
216                        );
217                    }
218                }
219            }
220
221            let process = Arc::new(std::sync::Mutex::new(proc));
222
223            for sf in &plan.filters {
224                let filter = Box::new(crate::python_process::SubprocessFilter::new(
225                    process.clone(),
226                    sf.node_id.clone(),
227                    sf.trainable,
228                ));
229                self.filters.register(&sf.node_id, filter);
230                if let Some(state) = &sf.state {
231                    self.filters.set_state(&sf.node_id, state.clone());
232                }
233            }
234
235            tracing::info!("Filters registered, Python process ready");
236        }
237
238        // Resolve input via InputSource::resolve()
239        let input_value = plan
240            .input
241            .as_ref()
242            .map(|src| src.resolve(self.data_store.as_deref(), &self.temp_store));
243
244        // DataStore-backed streaming: if input is a large DataRef and we have a store,
245        // read chunks via get_rows() and process with StreamExecutor (no full materialization).
246        if let Some(InputSource::Reference { data_ref }) = &plan.input
247            && let Some(store) = self.data_store.clone()
248            && let Ok(meta) = store.meta(data_ref)
249            && meta.total_rows > 1024
250        {
251            return self.execute_streamed_from_store(plan, &store, data_ref, &meta, start);
252        }
253
254        // Delegate to LocalRunner (same execution path as local)
255        let runner = somatize_runtime::LocalRunner;
256        let x = input_value.unwrap_or(Value::Empty);
257
258        let result = match &plan.mode {
259            ExecutionMode::Fit { y, batch_size } => {
260                // If batch_size is set, use BATCHED_FIT on the subprocess directly
261                if let Some(bs) = batch_size {
262                    tracing::info!(batch_size = bs, "Using batched fit");
263                    let node_ids = plan
264                        .plan
265                        .node_ids()
266                        .iter()
267                        .map(|s| s.to_string())
268                        .collect::<Vec<_>>();
269                    if let Some(filter) = self.filters.get(&node_ids[0]) {
270                        if let Some(sf) = filter
271                            .as_any()
272                            .downcast_ref::<crate::python_process::SubprocessFilter>()
273                        {
274                            let result = sf
275                                .process
276                                .lock()
277                                .map_err(|e| {
278                                    somatize_core::error::SomaError::Other(format!("mutex: {e}"))
279                                })
280                                .and_then(|mut proc| {
281                                    proc.batched_fit(&node_ids, &x, y.as_ref(), *bs)
282                                });
283                            match result {
284                                Ok((output, states)) => {
285                                    for (id, state) in &states {
286                                        self.filters.set_state(id, state.clone());
287                                    }
288                                    Ok((output, states))
289                                }
290                                Err(e) => Err(e),
291                            }
292                        } else {
293                            Err(somatize_core::error::SomaError::Other(
294                                "batched_fit requires SubprocessFilter".into(),
295                            ))
296                        }
297                    } else {
298                        Err(somatize_core::error::SomaError::Other(
299                            "no filters found".into(),
300                        ))
301                    }
302                } else {
303                    runner
304                        .fit(
305                            &plan.plan,
306                            &self.filters,
307                            self.cache.as_ref(),
308                            &self.event_bus,
309                            &x,
310                            y.as_ref(),
311                        )
312                        .map(|(output, all_outputs)| {
313                            // Extract trained states (prefixed __state_) and store in library
314                            let mut trained_states = std::collections::HashMap::new();
315                            for (key, value) in &all_outputs {
316                                if let Some(node_id) = key.strip_prefix("__state_") {
317                                    self.filters.set_state(node_id, value.clone());
318                                    trained_states.insert(node_id.to_string(), value.clone());
319                                }
320                            }
321                            (output, trained_states)
322                        })
323                }
324            }
325            ExecutionMode::Forward => runner
326                .forward(
327                    &plan.plan,
328                    &self.filters,
329                    self.cache.as_ref(),
330                    &self.event_bus,
331                    &x,
332                )
333                .map(|output| (output, std::collections::HashMap::new())),
334        };
335
336        let elapsed = start.elapsed().as_millis() as u64;
337        match result {
338            Ok((output, states)) => {
339                tracing::info!(
340                    duration_ms = elapsed,
341                    n_states = states.len(),
342                    "Plan completed successfully"
343                );
344                PlanResult::Success {
345                    output: self.wrap_output(output),
346                    duration_ms: elapsed,
347                    states,
348                }
349            }
350            Err(e) => {
351                tracing::error!(duration_ms = elapsed, error = %e, "Plan failed");
352                PlanResult::Failed {
353                    error: e.to_string(),
354                    duration_ms: elapsed,
355                }
356            }
357        }
358    }
359
360    /// DataStore-backed streaming: read chunks via get_rows(), process with StreamExecutor.
361    /// Avoids loading the entire dataset into memory.
362    fn execute_streamed_from_store(
363        &mut self,
364        plan: &SerializedPlan,
365        store: &Arc<dyn DataStore>,
366        data_ref: &somatize_core::store::DataRef,
367        meta: &somatize_core::store::StoreMeta,
368        start: Instant,
369    ) -> PlanResult {
370        use somatize_runtime::executors::stream::{FittedFilter, StreamExecutor};
371
372        let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
373        let fitted: Vec<FittedFilter> = node_ids
374            .iter()
375            .filter_map(|id| {
376                let filter = self.filters.get(id)?;
377                let state = self
378                    .filters
379                    .get_state(id)
380                    .unwrap_or_else(|| Arc::new(Value::Empty));
381                Some(FittedFilter {
382                    name: id.clone(),
383                    filter,
384                    state,
385                })
386            })
387            .collect();
388
389        let mut executor = StreamExecutor::new(fitted);
390        let chunk_size = 1024;
391        let run_id = format!("worker_stream_{}", plan.plan_id);
392
393        self.event_bus.emit(Event::RunStarted {
394            run_id: run_id.clone(),
395            plan_summary: somatize_core::event::PlanSummary {
396                total_nodes: node_ids.len(),
397                cached_nodes: 0,
398                parallel_branches: 0,
399            },
400        });
401
402        let mut last_output = Value::Empty;
403        let total = meta.total_rows;
404        let mut chunk_idx = 0;
405
406        for row_start in (0..total).step_by(chunk_size) {
407            let len = chunk_size.min(total - row_start);
408            let chunk = match store.get_rows(data_ref, row_start, len) {
409                Ok(c) => c,
410                Err(e) => {
411                    return PlanResult::Failed {
412                        error: format!("get_rows({row_start}..{}): {e}", row_start + len),
413                        duration_ms: start.elapsed().as_millis() as u64,
414                    };
415                }
416            };
417
418            match executor.process_chunk(chunk) {
419                Ok(Some(output)) => last_output = output,
420                Ok(None) => {} // Barrier — accumulating
421                Err(e) => {
422                    return PlanResult::Failed {
423                        error: format!("stream chunk {chunk_idx}: {e}"),
424                        duration_ms: start.elapsed().as_millis() as u64,
425                    };
426                }
427            }
428            chunk_idx += 1;
429        }
430
431        // Flush barrier filters
432        match executor.flush() {
433            Ok(Some(output)) => last_output = output,
434            Ok(None) => {}
435            Err(e) => {
436                return PlanResult::Failed {
437                    error: format!("stream flush: {e}"),
438                    duration_ms: start.elapsed().as_millis() as u64,
439                };
440            }
441        }
442
443        tracing::info!(
444            "Streamed {chunk_idx} chunks ({total} rows) in {}ms",
445            start.elapsed().as_millis()
446        );
447
448        PlanResult::Success {
449            output: self.wrap_output(last_output),
450            duration_ms: start.elapsed().as_millis() as u64,
451            states: std::collections::HashMap::new(),
452        }
453    }
454
455    /// Check if this worker matches a remote target.
456    pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
457        match target {
458            somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
459            somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
460        }
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_compiler::ExecutionPlan;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{Distribution, FilterKind, FilterMeta, RemoteTarget, StreamMode};
    use somatize_core::value::Value;

    /// Stateless test filter: doubles every element of a tensor input and
    /// passes any other value through unchanged.
    struct TestDoubler;

    impl Filter for TestDoubler {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"TestDoubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            let Value::Tensor { values, shape } = x else {
                return Ok(x.clone());
            };
            let doubled = values.iter().map(|v| v * 2.0).collect::<Vec<f64>>();
            Ok(Value::tensor(doubled, shape.clone()))
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "TestDoubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Worker "test_worker" with 4 cores and tags `cpu` + `test`.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["cpu".into(), "test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// Plan step that executes the single node `node_id`.
    fn single(node_id: &str) -> ExecutionPlan {
        ExecutionPlan::Execute {
            node_id: node_id.into(),
        }
    }

    /// Plan with no serialized filters, the default execution mode and an
    /// optional inline input.
    fn inline_plan(plan_id: &str, plan: ExecutionPlan, input: Option<Value>) -> SerializedPlan {
        SerializedPlan {
            plan_id: plan_id.into(),
            plan,
            input: input.map(|value| crate::protocol::InputSource::Inline { value }),
            filters: vec![],
            mode: ExecutionMode::default(),
            metadata: serde_json::json!({}),
        }
    }

    /// Unwrap an inline delivery, panicking on any other kind.
    fn inline_output(output: OutputDelivery) -> Value {
        match output {
            OutputDelivery::Inline { value } => value,
            _ => panic!("expected inline output"),
        }
    }

    #[test]
    fn worker_registration() {
        let WorkerToCoordinator::Register {
            worker_id,
            capabilities,
        } = make_worker().registration_message()
        else {
            panic!("wrong message type");
        };
        assert_eq!(worker_id, "test_worker");
        assert_eq!(capabilities.cpu_cores, 4);
    }

    #[test]
    fn worker_executes_plan_successfully() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let plan = inline_plan(
            "p_001",
            single("doubler"),
            Some(Value::tensor(vec![1.0, 2.0, 3.0], vec![3])),
        );

        match worker.execute_plan(&plan) {
            PlanResult::Success {
                output,
                duration_ms,
                ..
            } => {
                let value = inline_output(output);
                let (data, _) = value.as_tensor().unwrap();
                assert_eq!(data, &[2.0, 4.0, 6.0]);
                assert!(duration_ms < 1000);
            }
            result => panic!("expected success, got: {result:?}"),
        }
    }

    #[test]
    fn worker_handles_missing_filter() {
        // Nothing registered on purpose.
        let mut worker = make_worker();
        let plan = inline_plan("p_002", single("nonexistent"), None);

        let result = worker.execute_plan(&plan);
        assert!(matches!(result, PlanResult::Failed { .. }));
    }

    #[test]
    fn worker_matches_target_by_id() {
        let worker = make_worker();
        assert!(worker.matches_target(&RemoteTarget::WorkerId("test_worker".into())));
        assert!(!worker.matches_target(&RemoteTarget::WorkerId("other".into())));
    }

    #[test]
    fn worker_matches_target_by_tag() {
        let worker = make_worker();
        for tag in ["cpu", "test"] {
            assert!(worker.matches_target(&RemoteTarget::Tag(tag.into())));
        }
        assert!(!worker.matches_target(&RemoteTarget::Tag("gpu".into())));
    }

    #[test]
    fn worker_executes_sequence() {
        let mut worker = make_worker();
        for node in ["d1", "d2"] {
            worker.register_filter(node, Box::new(TestDoubler));
        }
        let plan = inline_plan(
            "p_003",
            ExecutionPlan::Sequence(vec![single("d1"), single("d2")]),
            Some(Value::tensor(vec![5.0], vec![1])),
        );

        let PlanResult::Success { output, .. } = worker.execute_plan(&plan) else {
            panic!("expected success");
        };
        let value = inline_output(output);
        let (data, _) = value.as_tensor().unwrap();
        assert_eq!(data, &[20.0]); // 5 * 2 * 2
    }

    #[test]
    fn worker_emits_events() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let mut rx = worker.subscribe();
        let plan = inline_plan(
            "p_004",
            single("doubler"),
            Some(Value::tensor(vec![1.0], vec![1])),
        );

        worker.execute_plan(&plan);

        // Drain everything already buffered; stops at the first receive error.
        let events: Vec<Event> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        assert!(events.iter().any(|e| matches!(e, Event::NodeStarted { .. })));
        assert!(events.iter().any(|e| matches!(e, Event::NodeCompleted { .. })));
    }
}