Skip to main content

somatize_worker/
worker.rs

1//! Worker — receives and executes plans from a coordinator.
2
3use crate::protocol::*;
4use somatize_core::cache::CacheStore;
5use somatize_core::event::Event;
6use somatize_core::filter::Filter;
7use somatize_core::store::{DataStore, LocalDataStore};
8use somatize_core::value::Value;
9use somatize_runtime::{EventBus, FilterLibrary, MemoryCache, Runner};
10use std::sync::Arc;
11use std::time::Instant;
12
13/// Worker state: manages execution of plans received from a coordinator.
14pub struct Worker {
15    pub id: WorkerId,
16    pub capabilities: Capabilities,
17    event_bus: Arc<EventBus>,
18    cache: Arc<dyn CacheStore>,
19    filters: FilterLibrary,
20    /// Optional persistent DataStore (S3, Zarr, etc.) — configured by user.
21    data_store: Option<Arc<dyn DataStore>>,
22    /// Temporary local store for HTTP bulk uploads — auto-created, auto-cleaned.
23    temp_store: Arc<LocalDataStore>,
24    /// Environment manager for creating venvs with filter dependencies.
25    env_manager: crate::env_manager::EnvManager,
26}
27
28impl Worker {
29    pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
30        let worker_id: String = id.into();
31        let temp_path = std::env::temp_dir().join(format!("soma-uploads-{worker_id}"));
32        let temp_store = LocalDataStore::new(temp_path);
33        let env_path = std::env::temp_dir().join(format!("soma-envs-{worker_id}"));
34        Self {
35            id: worker_id,
36            capabilities,
37            event_bus: Arc::new(EventBus::new(256)),
38            cache: Arc::new(MemoryCache::default()),
39            filters: FilterLibrary::new(),
40            data_store: None,
41            temp_store: Arc::new(temp_store),
42            env_manager: crate::env_manager::EnvManager::new(
43                env_path,
44                crate::env_manager::EnvType::Venv,
45            ),
46        }
47    }
48
49    /// Set a custom cache store (e.g. tiered or shared).
50    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
51        self.cache = cache;
52        self
53    }
54
55    /// Set a persistent DataStore (S3, Zarr, etc.) for large data references.
56    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
57        self.data_store = Some(store);
58        self
59    }
60
61    /// Set a custom temp directory for HTTP bulk uploads.
62    pub fn with_temp_dir(mut self, path: std::path::PathBuf) -> Self {
63        self.temp_store = Arc::new(LocalDataStore::new(path));
64        self
65    }
66
67    /// Get the temp store (for HTTP upload endpoint).
68    pub fn temp_store(&self) -> &Arc<LocalDataStore> {
69        &self.temp_store
70    }
71
72    /// Register a filter that this worker can execute.
73    pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
74        self.filters.register(node_id, filter);
75    }
76
77    /// Get a filter by node_id (for stream executor construction).
78    pub fn get_filter(&self, node_id: &str) -> Option<Arc<dyn Filter>> {
79        self.filters.get(node_id)
80    }
81
82    /// Get trained state for a filter.
83    pub fn get_filter_state(&self, node_id: &str) -> Value {
84        self.filters
85            .get_state(node_id)
86            .cloned()
87            .unwrap_or(Value::Empty)
88    }
89
90    /// Set trained state for a filter.
91    pub fn set_filter_state(&mut self, node_id: &str, state: Value) {
92        self.filters.set_state(node_id, state);
93    }
94
95    /// Wrap output in the right delivery: inline for small, DataRef for large.
96    pub fn wrap_output(&self, output: Value) -> OutputDelivery {
97        let size = serde_json::to_vec(&output).map(|v| v.len()).unwrap_or(0);
98        if size >= somatize_core::store::INLINE_THRESHOLD_BYTES {
99            let key = somatize_core::cache::CacheKey::hash_data(
100                &serde_json::to_vec(&output).unwrap_or_default(),
101            );
102            if let Ok(data_ref) = self.temp_store.put(&key, &output) {
103                return OutputDelivery::Reference { data_ref };
104            }
105        }
106        OutputDelivery::Inline { value: output }
107    }
108
109    /// Subscribe to execution events.
110    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
111        self.event_bus.subscribe()
112    }
113
114    /// Build a registration message.
115    pub fn registration_message(&self) -> WorkerToCoordinator {
116        WorkerToCoordinator::Register {
117            worker_id: self.id.clone(),
118            capabilities: self.capabilities.clone(),
119        }
120    }
121
122    /// Execute a serialized plan.
123    ///
124    /// If the plan contains serialized filter definitions, they are registered
125    /// temporarily for this execution (alongside any pre-registered filters).
126    ///
127    /// In **Fit** mode: fits each filter (topological order), stores trained states,
128    /// then forwards to propagate outputs. Returns states so the client can cache them.
129    ///
130    /// In **Forward** mode: executes the compiled plan directly.
131    pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
132        let start = Instant::now();
133        let _span = tracing::info_span!(
134            "execute_plan",
135            plan_id = %plan.plan_id,
136            n_filters = plan.filters.len(),
137            mode = ?plan.mode,
138        )
139        .entered();
140
141        tracing::info!(
142            "Plan received: {} filters, mode={:?}",
143            plan.filters.len(),
144            plan.mode
145        );
146
147        // Collect all requirements from serialized filters
148        let all_reqs: Vec<String> = plan
149            .filters
150            .iter()
151            .flat_map(|sf| sf.requirements.iter().cloned())
152            .collect::<std::collections::HashSet<_>>()
153            .into_iter()
154            .collect();
155
156        // Create/reuse venv if there are pip requirements, otherwise use system python
157        let python_path = if all_reqs.is_empty() {
158            "python3".to_string()
159        } else {
160            let reqs_str = all_reqs.join("\n");
161            match self.env_manager.ensure_env(&plan.plan_id, &reqs_str) {
162                Ok(path) => {
163                    tracing::info!("Using venv for plan {}: {:?}", plan.plan_id, path);
164                    path.to_string_lossy().to_string()
165                }
166                Err(e) => {
167                    tracing::warn!("Failed to create venv, falling back to system python: {e}");
168                    "python3".to_string()
169                }
170            }
171        };
172
173        // No site-packages resolution needed — subprocess uses the venv python directly
174
175        // Spawn ONE Python subprocess for all filters in this plan.
176        // All filters share the same process (needed for Composite autograd).
177        let filter_specs: Vec<(String, Vec<u8>, bool)> = plan
178            .filters
179            .iter()
180            .map(|sf| (sf.node_id.clone(), sf.pickled_filter.clone(), sf.trainable))
181            .collect();
182
183        if !filter_specs.is_empty() {
184            let filter_names: Vec<&str> =
185                plan.filters.iter().map(|sf| sf.node_id.as_str()).collect();
186            tracing::info!(
187                python = %python_path,
188                filters = ?filter_names,
189                "Spawning Python process for {} filters",
190                filter_specs.len()
191            );
192
193            let mut proc = crate::python_process::PythonProcess::spawn(&python_path, &filter_specs)
194                .map_err(|e| {
195                    tracing::error!("Failed to spawn Python process: {e}");
196                    e
197                })
198                .expect("PythonProcess spawn failed");
199
200            // Load trained states from previous epochs (SET_STATE)
201            for sf in &plan.filters {
202                if let Some(state) = &sf.state {
203                    let size = match state {
204                        Value::Bytes(b) => b.len(),
205                        _ => 0,
206                    };
207                    tracing::info!(
208                        node_id = %sf.node_id,
209                        size_bytes = size,
210                        "Loading trained state from previous epoch"
211                    );
212                    if let Err(e) = proc.set_state(&sf.node_id, state) {
213                        tracing::warn!(
214                            node_id = %sf.node_id,
215                            error = %e,
216                            "Failed to load state (will use fresh weights)"
217                        );
218                    }
219                }
220            }
221
222            let process = Arc::new(std::sync::Mutex::new(proc));
223
224            for sf in &plan.filters {
225                let filter = Box::new(crate::python_process::SubprocessFilter::new(
226                    process.clone(),
227                    sf.node_id.clone(),
228                    sf.trainable,
229                ));
230                self.filters.register(&sf.node_id, filter);
231                if let Some(state) = &sf.state {
232                    self.filters.set_state(&sf.node_id, state.clone());
233                }
234            }
235
236            tracing::info!("Filters registered, Python process ready");
237        }
238
239        // Resolve input via InputSource::resolve()
240        let input_value = plan
241            .input
242            .as_ref()
243            .map(|src| src.resolve(self.data_store.as_deref(), &self.temp_store));
244
245        // DataStore-backed streaming: if input is a large DataRef and we have a store,
246        // read chunks via get_rows() and process with StreamExecutor (no full materialization).
247        if let Some(InputSource::Reference { data_ref }) = &plan.input
248            && let Some(store) = self.data_store.clone()
249            && let Ok(meta) = store.meta(data_ref)
250            && meta.total_rows > 1024
251        {
252            return self.execute_streamed_from_store(plan, &store, data_ref, &meta, start);
253        }
254
255        // Delegate to LocalRunner (same execution path as local)
256        let runner = somatize_runtime::LocalRunner;
257        let x = input_value.unwrap_or(Value::Empty);
258
259        let result = match &plan.mode {
260            ExecutionMode::Fit { y } => runner
261                .fit(
262                    &plan.plan,
263                    &self.filters,
264                    self.cache.as_ref(),
265                    &self.event_bus,
266                    &x,
267                    y.as_ref(),
268                )
269                .map(|(output, all_outputs)| {
270                    // Extract trained states (prefixed __state_) and store in library
271                    let mut trained_states = std::collections::HashMap::new();
272                    for (key, value) in &all_outputs {
273                        if let Some(node_id) = key.strip_prefix("__state_") {
274                            self.filters.set_state(node_id, value.clone());
275                            trained_states.insert(node_id.to_string(), value.clone());
276                        }
277                    }
278                    (output, trained_states)
279                }),
280            ExecutionMode::Forward => runner
281                .forward(
282                    &plan.plan,
283                    &self.filters,
284                    self.cache.as_ref(),
285                    &self.event_bus,
286                    &x,
287                )
288                .map(|output| (output, std::collections::HashMap::new())),
289        };
290
291        let elapsed = start.elapsed().as_millis() as u64;
292        match result {
293            Ok((output, states)) => {
294                tracing::info!(
295                    duration_ms = elapsed,
296                    n_states = states.len(),
297                    "Plan completed successfully"
298                );
299                PlanResult::Success {
300                    output: self.wrap_output(output),
301                    duration_ms: elapsed,
302                    states,
303                }
304            }
305            Err(e) => {
306                tracing::error!(duration_ms = elapsed, error = %e, "Plan failed");
307                PlanResult::Failed {
308                    error: e.to_string(),
309                    duration_ms: elapsed,
310                }
311            }
312        }
313    }
314
315    /// DataStore-backed streaming: read chunks via get_rows(), process with StreamExecutor.
316    /// Avoids loading the entire dataset into memory.
317    fn execute_streamed_from_store(
318        &mut self,
319        plan: &SerializedPlan,
320        store: &Arc<dyn DataStore>,
321        data_ref: &somatize_core::store::DataRef,
322        meta: &somatize_core::store::StoreMeta,
323        start: Instant,
324    ) -> PlanResult {
325        use somatize_runtime::executors::stream::{FittedFilter, StreamExecutor};
326
327        let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
328        let fitted: Vec<FittedFilter> = node_ids
329            .iter()
330            .filter_map(|id| {
331                let filter = self.filters.get(id)?;
332                let state = self.filters.get_state(id).cloned().unwrap_or(Value::Empty);
333                Some(FittedFilter {
334                    name: id.clone(),
335                    filter,
336                    state,
337                })
338            })
339            .collect();
340
341        let mut executor = StreamExecutor::new(fitted);
342        let chunk_size = 1024;
343        let run_id = format!("worker_stream_{}", plan.plan_id);
344
345        self.event_bus.emit(Event::RunStarted {
346            run_id: run_id.clone(),
347            plan_summary: somatize_core::event::PlanSummary {
348                total_nodes: node_ids.len(),
349                cached_nodes: 0,
350                parallel_branches: 0,
351            },
352        });
353
354        let mut last_output = Value::Empty;
355        let total = meta.total_rows;
356        let mut chunk_idx = 0;
357
358        for row_start in (0..total).step_by(chunk_size) {
359            let len = chunk_size.min(total - row_start);
360            let chunk = match store.get_rows(data_ref, row_start, len) {
361                Ok(c) => c,
362                Err(e) => {
363                    return PlanResult::Failed {
364                        error: format!("get_rows({row_start}..{}): {e}", row_start + len),
365                        duration_ms: start.elapsed().as_millis() as u64,
366                    };
367                }
368            };
369
370            match executor.process_chunk(chunk) {
371                Ok(Some(output)) => last_output = output,
372                Ok(None) => {} // Barrier — accumulating
373                Err(e) => {
374                    return PlanResult::Failed {
375                        error: format!("stream chunk {chunk_idx}: {e}"),
376                        duration_ms: start.elapsed().as_millis() as u64,
377                    };
378                }
379            }
380            chunk_idx += 1;
381        }
382
383        // Flush barrier filters
384        match executor.flush() {
385            Ok(Some(output)) => last_output = output,
386            Ok(None) => {}
387            Err(e) => {
388                return PlanResult::Failed {
389                    error: format!("stream flush: {e}"),
390                    duration_ms: start.elapsed().as_millis() as u64,
391                };
392            }
393        }
394
395        tracing::info!(
396            "Streamed {chunk_idx} chunks ({total} rows) in {}ms",
397            start.elapsed().as_millis()
398        );
399
400        PlanResult::Success {
401            output: self.wrap_output(last_output),
402            duration_ms: start.elapsed().as_millis() as u64,
403            states: std::collections::HashMap::new(),
404        }
405    }
406
407    /// Check if this worker matches a remote target.
408    pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
409        match target {
410            somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
411            somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
412        }
413    }
414}
415
416#[cfg(test)]
417mod tests {
418    use super::*;
419    use somatize_compiler::ExecutionPlan;
420    use somatize_core::cache::CacheKey;
421    use somatize_core::error::Result as SomaResult;
422    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
423    use somatize_core::value::Value;
424
425    struct TestDoubler;
426
427    impl Filter for TestDoubler {
428        fn config_hash(&self) -> CacheKey {
429            CacheKey::from_parts(&[b"TestDoubler"])
430        }
431        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
432            Ok(Value::Empty)
433        }
434        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
435            match x {
436                Value::Tensor { values, shape } => {
437                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
438                    Ok(Value::tensor(doubled, shape.clone()))
439                }
440                _ => Ok(x.clone()),
441            }
442        }
443        fn meta(&self) -> FilterMeta {
444            FilterMeta {
445                name: "TestDoubler".into(),
446                kind: FilterKind::Stateless,
447                cacheable: true,
448                differentiable: true,
449                stream_mode: StreamMode::FixedState,
450                distribution: somatize_core::filter::Distribution::Local,
451                input_schema: None,
452                output_schema: None,
453            }
454        }
455
456        fn as_any(&self) -> &dyn std::any::Any {
457            self
458        }
459    }
460
461    fn make_worker() -> Worker {
462        Worker::new(
463            "test_worker",
464            Capabilities {
465                cpu_cores: 4,
466                ram_bytes: 8_000_000_000,
467                gpus: vec![],
468                python_envs: vec![],
469                tags: vec!["cpu".into(), "test".into()],
470            },
471        )
472    }
473
474    #[test]
475    fn worker_registration() {
476        let worker = make_worker();
477        let msg = worker.registration_message();
478        if let WorkerToCoordinator::Register {
479            worker_id,
480            capabilities,
481        } = msg
482        {
483            assert_eq!(worker_id, "test_worker");
484            assert_eq!(capabilities.cpu_cores, 4);
485        } else {
486            panic!("wrong message type");
487        }
488    }
489
490    #[test]
491    fn worker_executes_plan_successfully() {
492        let mut worker = make_worker();
493        worker.register_filter("doubler", Box::new(TestDoubler));
494
495        let plan = SerializedPlan {
496            plan_id: "p_001".into(),
497            plan: ExecutionPlan::Execute {
498                node_id: "doubler".into(),
499            },
500            input: Some(crate::protocol::InputSource::Inline {
501                value: Value::tensor(vec![1.0, 2.0, 3.0], vec![3]),
502            }),
503            filters: vec![],
504            mode: ExecutionMode::default(),
505            metadata: serde_json::json!({}),
506        };
507
508        let result = worker.execute_plan(&plan);
509
510        if let PlanResult::Success {
511            output,
512            duration_ms,
513            ..
514        } = result
515        {
516            let value = match output {
517                OutputDelivery::Inline { value } => value,
518                _ => panic!("expected inline output"),
519            };
520            let (data, _) = value.as_tensor().unwrap();
521            assert_eq!(data, &[2.0, 4.0, 6.0]);
522            assert!(duration_ms < 1000);
523        } else {
524            panic!("expected success, got: {result:?}");
525        }
526    }
527
528    #[test]
529    fn worker_handles_missing_filter() {
530        let mut worker = make_worker();
531        // Don't register any filters
532
533        let plan = SerializedPlan {
534            plan_id: "p_002".into(),
535            plan: ExecutionPlan::Execute {
536                node_id: "nonexistent".into(),
537            },
538            input: None,
539            filters: vec![],
540            mode: ExecutionMode::default(),
541            metadata: serde_json::json!({}),
542        };
543
544        let result = worker.execute_plan(&plan);
545        assert!(matches!(result, PlanResult::Failed { .. }));
546    }
547
548    #[test]
549    fn worker_matches_target_by_id() {
550        let worker = make_worker();
551        assert!(
552            worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
553                "test_worker".into()
554            ))
555        );
556        assert!(
557            !worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
558                "other".into()
559            ))
560        );
561    }
562
563    #[test]
564    fn worker_matches_target_by_tag() {
565        let worker = make_worker();
566        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("cpu".into())));
567        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("test".into())));
568        assert!(!worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("gpu".into())));
569    }
570
571    #[test]
572    fn worker_executes_sequence() {
573        let mut worker = make_worker();
574        worker.register_filter("d1", Box::new(TestDoubler));
575        worker.register_filter("d2", Box::new(TestDoubler));
576
577        let plan = SerializedPlan {
578            plan_id: "p_003".into(),
579            plan: ExecutionPlan::Sequence(vec![
580                ExecutionPlan::Execute {
581                    node_id: "d1".into(),
582                },
583                ExecutionPlan::Execute {
584                    node_id: "d2".into(),
585                },
586            ]),
587            input: Some(crate::protocol::InputSource::Inline {
588                value: Value::tensor(vec![5.0], vec![1]),
589            }),
590            filters: vec![],
591            mode: ExecutionMode::default(),
592            metadata: serde_json::json!({}),
593        };
594
595        let result = worker.execute_plan(&plan);
596        if let PlanResult::Success { output, .. } = result {
597            let value = match output {
598                OutputDelivery::Inline { value } => value,
599                _ => panic!("expected inline output"),
600            };
601            let (data, _) = value.as_tensor().unwrap();
602            assert_eq!(data, &[20.0]); // 5 * 2 * 2
603        } else {
604            panic!("expected success");
605        }
606    }
607
608    #[test]
609    fn worker_emits_events() {
610        let mut worker = make_worker();
611        worker.register_filter("doubler", Box::new(TestDoubler));
612        let mut rx = worker.subscribe();
613
614        let plan = SerializedPlan {
615            plan_id: "p_004".into(),
616            plan: ExecutionPlan::Execute {
617                node_id: "doubler".into(),
618            },
619            input: Some(crate::protocol::InputSource::Inline {
620                value: Value::tensor(vec![1.0], vec![1]),
621            }),
622            filters: vec![],
623            mode: ExecutionMode::default(),
624            metadata: serde_json::json!({}),
625        };
626
627        worker.execute_plan(&plan);
628
629        let mut events = Vec::new();
630        while let Ok(e) = rx.try_recv() {
631            events.push(e);
632        }
633        assert!(
634            events
635                .iter()
636                .any(|e| matches!(e, Event::NodeStarted { .. }))
637        );
638        assert!(
639            events
640                .iter()
641                .any(|e| matches!(e, Event::NodeCompleted { .. }))
642        );
643    }
644}