1use crate::protocol::*;
4use somatize_core::cache::CacheStore;
5use somatize_core::event::Event;
6use somatize_core::filter::Filter;
7use somatize_core::store::{DataStore, LocalDataStore};
8use somatize_core::value::Value;
9use somatize_runtime::{EventBus, FilterLibrary, MemoryCache, Runner};
10use std::sync::Arc;
11use std::time::Instant;
12
/// A compute node that executes [`SerializedPlan`]s sent by a coordinator.
///
/// Holds its own event bus, cache, filter library and scratch storage. Filters
/// can be registered directly (`register_filter`) or shipped inside a plan, in
/// which case they are hosted in a Python subprocess.
pub struct Worker {
    /// Identifier reported to the coordinator and matched by `RemoteTarget::WorkerId`.
    pub id: WorkerId,
    /// Advertised resources; `tags` are matched by `RemoteTarget::Tag`.
    pub capabilities: Capabilities,
    /// Bus that execution events are emitted on; see `Worker::subscribe`.
    event_bus: Arc<EventBus>,
    /// Cache handed to the runner for every plan execution.
    cache: Arc<dyn CacheStore>,
    /// Registered filters and their per-node state.
    filters: FilterLibrary,
    /// Optional shared store used to resolve `InputSource::Reference` inputs.
    data_store: Option<Arc<dyn DataStore>>,
    /// Worker-local store that large outputs are written to (see `wrap_output`).
    temp_store: Arc<LocalDataStore>,
    /// Creates/reuses Python environments for plans that declare requirements.
    env_manager: crate::env_manager::EnvManager,
}
27
28impl Worker {
29 pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
30 let worker_id: String = id.into();
31 let temp_path = std::env::temp_dir().join(format!("soma-uploads-{worker_id}"));
32 let temp_store = LocalDataStore::new(temp_path);
33 let env_path = std::env::temp_dir().join(format!("soma-envs-{worker_id}"));
34 Self {
35 id: worker_id,
36 capabilities,
37 event_bus: Arc::new(EventBus::new(256)),
38 cache: Arc::new(MemoryCache::default()),
39 filters: FilterLibrary::new(),
40 data_store: None,
41 temp_store: Arc::new(temp_store),
42 env_manager: crate::env_manager::EnvManager::new(
43 env_path,
44 crate::env_manager::EnvType::Venv,
45 ),
46 }
47 }
48
49 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
51 self.cache = cache;
52 self
53 }
54
55 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
57 self.data_store = Some(store);
58 self
59 }
60
61 pub fn with_temp_dir(mut self, path: std::path::PathBuf) -> Self {
63 self.temp_store = Arc::new(LocalDataStore::new(path));
64 self
65 }
66
67 pub fn temp_store(&self) -> &Arc<LocalDataStore> {
69 &self.temp_store
70 }
71
72 pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
74 self.filters.register(node_id, filter);
75 }
76
77 pub fn get_filter(&self, node_id: &str) -> Option<Arc<dyn Filter>> {
79 self.filters.get(node_id)
80 }
81
82 pub fn get_filter_state(&self, node_id: &str) -> Arc<Value> {
84 self.filters
85 .get_state(node_id)
86 .unwrap_or_else(|| Arc::new(Value::Empty))
87 }
88
89 pub fn set_filter_state(&mut self, node_id: &str, state: Value) {
91 self.filters.set_state(node_id, state);
92 }
93
94 pub fn wrap_output(&self, output: Value) -> OutputDelivery {
96 let size = serde_json::to_vec(&output).map(|v| v.len()).unwrap_or(0);
97 if size >= somatize_core::store::INLINE_THRESHOLD_BYTES {
98 let key = somatize_core::cache::CacheKey::hash_data(
99 &serde_json::to_vec(&output).unwrap_or_default(),
100 );
101 if let Ok(data_ref) = self.temp_store.put(&key, &output) {
102 return OutputDelivery::Reference { data_ref };
103 }
104 }
105 OutputDelivery::Inline { value: output }
106 }
107
108 pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
110 self.event_bus.subscribe()
111 }
112
113 pub fn registration_message(&self) -> WorkerToCoordinator {
115 WorkerToCoordinator::Register {
116 worker_id: self.id.clone(),
117 capabilities: self.capabilities.clone(),
118 }
119 }
120
121 pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
131 let start = Instant::now();
132 let _span = tracing::info_span!(
133 "execute_plan",
134 plan_id = %plan.plan_id,
135 n_filters = plan.filters.len(),
136 mode = ?plan.mode,
137 )
138 .entered();
139
140 tracing::info!(
141 "Plan received: {} filters, mode={:?}",
142 plan.filters.len(),
143 plan.mode
144 );
145
146 let all_reqs: Vec<String> = plan
148 .filters
149 .iter()
150 .flat_map(|sf| sf.requirements.iter().cloned())
151 .collect::<std::collections::HashSet<_>>()
152 .into_iter()
153 .collect();
154
155 let python_path = if all_reqs.is_empty() {
157 "python3".to_string()
158 } else {
159 let reqs_str = all_reqs.join("\n");
160 match self.env_manager.ensure_env(&plan.plan_id, &reqs_str) {
161 Ok(path) => {
162 tracing::info!("Using venv for plan {}: {:?}", plan.plan_id, path);
163 path.to_string_lossy().to_string()
164 }
165 Err(e) => {
166 tracing::warn!("Failed to create venv, falling back to system python: {e}");
167 "python3".to_string()
168 }
169 }
170 };
171
172 let filter_specs: Vec<(String, Vec<u8>, bool)> = plan
177 .filters
178 .iter()
179 .map(|sf| (sf.node_id.clone(), sf.pickled_filter.clone(), sf.trainable))
180 .collect();
181
182 if !filter_specs.is_empty() {
183 let filter_names: Vec<&str> =
184 plan.filters.iter().map(|sf| sf.node_id.as_str()).collect();
185 tracing::info!(
186 python = %python_path,
187 filters = ?filter_names,
188 "Spawning Python process for {} filters",
189 filter_specs.len()
190 );
191
192 let mut proc = crate::python_process::PythonProcess::spawn(&python_path, &filter_specs)
193 .map_err(|e| {
194 tracing::error!("Failed to spawn Python process: {e}");
195 e
196 })
197 .expect("PythonProcess spawn failed");
198
199 for sf in &plan.filters {
201 if let Some(state) = &sf.state {
202 let size = match state {
203 Value::Bytes(b) => b.len(),
204 _ => 0,
205 };
206 tracing::info!(
207 node_id = %sf.node_id,
208 size_bytes = size,
209 "Loading trained state from previous epoch"
210 );
211 if let Err(e) = proc.set_state(&sf.node_id, state) {
212 tracing::warn!(
213 node_id = %sf.node_id,
214 error = %e,
215 "Failed to load state (will use fresh weights)"
216 );
217 }
218 }
219 }
220
221 let process = Arc::new(std::sync::Mutex::new(proc));
222
223 for sf in &plan.filters {
224 let filter = Box::new(crate::python_process::SubprocessFilter::new(
225 process.clone(),
226 sf.node_id.clone(),
227 sf.trainable,
228 ));
229 self.filters.register(&sf.node_id, filter);
230 if let Some(state) = &sf.state {
231 self.filters.set_state(&sf.node_id, state.clone());
232 }
233 }
234
235 tracing::info!("Filters registered, Python process ready");
236 }
237
238 let input_value = plan
240 .input
241 .as_ref()
242 .map(|src| src.resolve(self.data_store.as_deref(), &self.temp_store));
243
244 if let Some(InputSource::Reference { data_ref }) = &plan.input
247 && let Some(store) = self.data_store.clone()
248 && let Ok(meta) = store.meta(data_ref)
249 && meta.total_rows > 1024
250 {
251 return self.execute_streamed_from_store(plan, &store, data_ref, &meta, start);
252 }
253
254 let runner = somatize_runtime::LocalRunner;
256 let x = input_value.unwrap_or(Value::Empty);
257
258 let result = match &plan.mode {
259 ExecutionMode::Fit { y, batch_size } => {
260 if let Some(bs) = batch_size {
262 tracing::info!(batch_size = bs, "Using batched fit");
263 let node_ids = plan
264 .plan
265 .node_ids()
266 .iter()
267 .map(|s| s.to_string())
268 .collect::<Vec<_>>();
269 if let Some(filter) = self.filters.get(&node_ids[0]) {
270 if let Some(sf) = filter
271 .as_any()
272 .downcast_ref::<crate::python_process::SubprocessFilter>()
273 {
274 let result = sf
275 .process
276 .lock()
277 .map_err(|e| {
278 somatize_core::error::SomaError::Other(format!("mutex: {e}"))
279 })
280 .and_then(|mut proc| {
281 proc.batched_fit(&node_ids, &x, y.as_ref(), *bs)
282 });
283 match result {
284 Ok((output, states)) => {
285 for (id, state) in &states {
286 self.filters.set_state(id, state.clone());
287 }
288 Ok((output, states))
289 }
290 Err(e) => Err(e),
291 }
292 } else {
293 Err(somatize_core::error::SomaError::Other(
294 "batched_fit requires SubprocessFilter".into(),
295 ))
296 }
297 } else {
298 Err(somatize_core::error::SomaError::Other(
299 "no filters found".into(),
300 ))
301 }
302 } else {
303 runner
304 .fit(
305 &plan.plan,
306 &self.filters,
307 self.cache.as_ref(),
308 &self.event_bus,
309 &x,
310 y.as_ref(),
311 )
312 .map(|(output, all_outputs)| {
313 let mut trained_states = std::collections::HashMap::new();
315 for (key, value) in &all_outputs {
316 if let Some(node_id) = key.strip_prefix("__state_") {
317 self.filters.set_state(node_id, value.clone());
318 trained_states.insert(node_id.to_string(), value.clone());
319 }
320 }
321 (output, trained_states)
322 })
323 }
324 }
325 ExecutionMode::Forward => runner
326 .forward(
327 &plan.plan,
328 &self.filters,
329 self.cache.as_ref(),
330 &self.event_bus,
331 &x,
332 )
333 .map(|output| (output, std::collections::HashMap::new())),
334 };
335
336 let elapsed = start.elapsed().as_millis() as u64;
337 match result {
338 Ok((output, states)) => {
339 tracing::info!(
340 duration_ms = elapsed,
341 n_states = states.len(),
342 "Plan completed successfully"
343 );
344 PlanResult::Success {
345 output: self.wrap_output(output),
346 duration_ms: elapsed,
347 states,
348 }
349 }
350 Err(e) => {
351 tracing::error!(duration_ms = elapsed, error = %e, "Plan failed");
352 PlanResult::Failed {
353 error: e.to_string(),
354 duration_ms: elapsed,
355 }
356 }
357 }
358 }
359
360 fn execute_streamed_from_store(
363 &mut self,
364 plan: &SerializedPlan,
365 store: &Arc<dyn DataStore>,
366 data_ref: &somatize_core::store::DataRef,
367 meta: &somatize_core::store::StoreMeta,
368 start: Instant,
369 ) -> PlanResult {
370 use somatize_runtime::executors::stream::{FittedFilter, StreamExecutor};
371
372 let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
373 let fitted: Vec<FittedFilter> = node_ids
374 .iter()
375 .filter_map(|id| {
376 let filter = self.filters.get(id)?;
377 let state = self
378 .filters
379 .get_state(id)
380 .unwrap_or_else(|| Arc::new(Value::Empty));
381 Some(FittedFilter {
382 name: id.clone(),
383 filter,
384 state,
385 })
386 })
387 .collect();
388
389 let mut executor = StreamExecutor::new(fitted);
390 let chunk_size = 1024;
391 let run_id = format!("worker_stream_{}", plan.plan_id);
392
393 self.event_bus.emit(Event::RunStarted {
394 run_id: run_id.clone(),
395 plan_summary: somatize_core::event::PlanSummary {
396 total_nodes: node_ids.len(),
397 cached_nodes: 0,
398 parallel_branches: 0,
399 },
400 });
401
402 let mut last_output = Value::Empty;
403 let total = meta.total_rows;
404 let mut chunk_idx = 0;
405
406 for row_start in (0..total).step_by(chunk_size) {
407 let len = chunk_size.min(total - row_start);
408 let chunk = match store.get_rows(data_ref, row_start, len) {
409 Ok(c) => c,
410 Err(e) => {
411 return PlanResult::Failed {
412 error: format!("get_rows({row_start}..{}): {e}", row_start + len),
413 duration_ms: start.elapsed().as_millis() as u64,
414 };
415 }
416 };
417
418 match executor.process_chunk(chunk) {
419 Ok(Some(output)) => last_output = output,
420 Ok(None) => {} Err(e) => {
422 return PlanResult::Failed {
423 error: format!("stream chunk {chunk_idx}: {e}"),
424 duration_ms: start.elapsed().as_millis() as u64,
425 };
426 }
427 }
428 chunk_idx += 1;
429 }
430
431 match executor.flush() {
433 Ok(Some(output)) => last_output = output,
434 Ok(None) => {}
435 Err(e) => {
436 return PlanResult::Failed {
437 error: format!("stream flush: {e}"),
438 duration_ms: start.elapsed().as_millis() as u64,
439 };
440 }
441 }
442
443 tracing::info!(
444 "Streamed {chunk_idx} chunks ({total} rows) in {}ms",
445 start.elapsed().as_millis()
446 );
447
448 PlanResult::Success {
449 output: self.wrap_output(last_output),
450 duration_ms: start.elapsed().as_millis() as u64,
451 states: std::collections::HashMap::new(),
452 }
453 }
454
455 pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
457 match target {
458 somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
459 somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
460 }
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_compiler::ExecutionPlan;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{
        Distribution, FilterKind, FilterMeta, RemoteTarget, StreamMode,
    };
    use somatize_core::value::Value;

    /// Stateless test filter: doubles every tensor element, passes anything else through.
    struct TestDoubler;

    impl Filter for TestDoubler {
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "TestDoubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"TestDoubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            let Value::Tensor { values, shape } = x else {
                return Ok(x.clone());
            };
            let doubled = values.iter().map(|v| v * 2.0).collect::<Vec<f64>>();
            Ok(Value::tensor(doubled, shape.clone()))
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Worker with a fixed id and `cpu`/`test` tags shared by every test.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["cpu".into(), "test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// One-node execution plan for `node_id`.
    fn step(node_id: &str) -> ExecutionPlan {
        ExecutionPlan::Execute {
            node_id: node_id.into(),
        }
    }

    /// Default-mode plan with no shipped filters and an optional inline input.
    fn plan_with(plan_id: &str, plan: ExecutionPlan, input: Option<Value>) -> SerializedPlan {
        SerializedPlan {
            plan_id: plan_id.into(),
            plan,
            input: input.map(|value| crate::protocol::InputSource::Inline { value }),
            filters: Vec::new(),
            mode: ExecutionMode::default(),
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn worker_registration() {
        let worker = make_worker();
        let WorkerToCoordinator::Register {
            worker_id,
            capabilities,
        } = worker.registration_message()
        else {
            panic!("wrong message type");
        };
        assert_eq!(worker_id, "test_worker");
        assert_eq!(capabilities.cpu_cores, 4);
    }

    #[test]
    fn worker_executes_plan_successfully() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let plan = plan_with(
            "p_001",
            step("doubler"),
            Some(Value::tensor(vec![1.0, 2.0, 3.0], vec![3])),
        );

        let result = worker.execute_plan(&plan);
        if let PlanResult::Success {
            output,
            duration_ms,
            ..
        } = result
        {
            let OutputDelivery::Inline { value } = output else {
                panic!("expected inline output");
            };
            let (data, _) = value.as_tensor().unwrap();
            assert_eq!(data, &[2.0, 4.0, 6.0]);
            assert!(duration_ms < 1000);
        } else {
            panic!("expected success, got: {result:?}");
        }
    }

    #[test]
    fn worker_handles_missing_filter() {
        let mut worker = make_worker();
        let plan = plan_with("p_002", step("nonexistent"), None);
        let outcome = worker.execute_plan(&plan);
        assert!(matches!(outcome, PlanResult::Failed { .. }));
    }

    #[test]
    fn worker_matches_target_by_id() {
        let worker = make_worker();
        for (id, expected) in [("test_worker", true), ("other", false)] {
            assert_eq!(
                worker.matches_target(&RemoteTarget::WorkerId(id.into())),
                expected
            );
        }
    }

    #[test]
    fn worker_matches_target_by_tag() {
        let worker = make_worker();
        for (tag, expected) in [("cpu", true), ("test", true), ("gpu", false)] {
            assert_eq!(worker.matches_target(&RemoteTarget::Tag(tag.into())), expected);
        }
    }

    #[test]
    fn worker_executes_sequence() {
        let mut worker = make_worker();
        for node in ["d1", "d2"] {
            worker.register_filter(node, Box::new(TestDoubler));
        }
        let plan = plan_with(
            "p_003",
            ExecutionPlan::Sequence(vec![step("d1"), step("d2")]),
            Some(Value::tensor(vec![5.0], vec![1])),
        );

        let PlanResult::Success { output, .. } = worker.execute_plan(&plan) else {
            panic!("expected success");
        };
        let OutputDelivery::Inline { value } = output else {
            panic!("expected inline output");
        };
        // 5.0 doubled twice.
        let (data, _) = value.as_tensor().unwrap();
        assert_eq!(data, &[20.0]);
    }

    #[test]
    fn worker_emits_events() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let mut rx = worker.subscribe();
        let plan = plan_with(
            "p_004",
            step("doubler"),
            Some(Value::tensor(vec![1.0], vec![1])),
        );

        worker.execute_plan(&plan);

        // Drain whatever was buffered; stops at the first receive error.
        let events: Vec<Event> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        let started = events
            .iter()
            .any(|e| matches!(e, Event::NodeStarted { .. }));
        let completed = events
            .iter()
            .any(|e| matches!(e, Event::NodeCompleted { .. }));
        assert!(started);
        assert!(completed);
    }
}
692}