1use crate::protocol::*;
4use somatize_core::cache::CacheStore;
5use somatize_core::event::Event;
6use somatize_core::filter::Filter;
7use somatize_core::store::{DataStore, LocalDataStore};
8use somatize_core::value::Value;
9use somatize_runtime::{EventBus, FilterLibrary, MemoryCache, Runner};
10use std::sync::Arc;
11use std::time::Instant;
12
/// Executes serialized plans on behalf of a coordinator.
///
/// Owns the filter registry, the cache, the event bus, and the stores used to
/// resolve plan inputs and to park oversized outputs.
pub struct Worker {
    /// Identifier sent to the coordinator in the registration message and
    /// compared against `RemoteTarget::WorkerId` in `matches_target`.
    pub id: WorkerId,
    /// Capabilities sent at registration; `tags` is also consulted by
    /// `matches_target`.
    pub capabilities: Capabilities,
    /// Bus passed to the runner during execution; `subscribe` hands out
    /// receivers on it.
    event_bus: Arc<EventBus>,
    /// Cache passed to the runner for every plan execution.
    cache: Arc<dyn CacheStore>,
    /// Registered filters together with their per-node state.
    filters: FilterLibrary,
    /// Optional store used when resolving plan inputs; `None` until
    /// `with_data_store` is called.
    data_store: Option<Arc<dyn DataStore>>,
    /// Local store that `wrap_output` writes oversized outputs into.
    temp_store: Arc<LocalDataStore>,
    /// Prepares a Python environment when a plan lists requirements.
    env_manager: crate::env_manager::EnvManager,
}
27
28impl Worker {
29 pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
30 let worker_id: String = id.into();
31 let temp_path = std::env::temp_dir().join(format!("soma-uploads-{worker_id}"));
32 let temp_store = LocalDataStore::new(temp_path);
33 let env_path = std::env::temp_dir().join(format!("soma-envs-{worker_id}"));
34 Self {
35 id: worker_id,
36 capabilities,
37 event_bus: Arc::new(EventBus::new(256)),
38 cache: Arc::new(MemoryCache::default()),
39 filters: FilterLibrary::new(),
40 data_store: None,
41 temp_store: Arc::new(temp_store),
42 env_manager: crate::env_manager::EnvManager::new(
43 env_path,
44 crate::env_manager::EnvType::Venv,
45 ),
46 }
47 }
48
49 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
51 self.cache = cache;
52 self
53 }
54
55 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
57 self.data_store = Some(store);
58 self
59 }
60
61 pub fn with_temp_dir(mut self, path: std::path::PathBuf) -> Self {
63 self.temp_store = Arc::new(LocalDataStore::new(path));
64 self
65 }
66
67 pub fn temp_store(&self) -> &Arc<LocalDataStore> {
69 &self.temp_store
70 }
71
72 pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
74 self.filters.register(node_id, filter);
75 }
76
77 pub fn get_filter(&self, node_id: &str) -> Option<Arc<dyn Filter>> {
79 self.filters.get(node_id)
80 }
81
82 pub fn get_filter_state(&self, node_id: &str) -> Value {
84 self.filters
85 .get_state(node_id)
86 .cloned()
87 .unwrap_or(Value::Empty)
88 }
89
90 pub fn set_filter_state(&mut self, node_id: &str, state: Value) {
92 self.filters.set_state(node_id, state);
93 }
94
95 pub fn wrap_output(&self, output: Value) -> OutputDelivery {
97 let size = serde_json::to_vec(&output).map(|v| v.len()).unwrap_or(0);
98 if size >= somatize_core::store::INLINE_THRESHOLD_BYTES {
99 let key = somatize_core::cache::CacheKey::hash_data(
100 &serde_json::to_vec(&output).unwrap_or_default(),
101 );
102 if let Ok(data_ref) = self.temp_store.put(&key, &output) {
103 return OutputDelivery::Reference { data_ref };
104 }
105 }
106 OutputDelivery::Inline { value: output }
107 }
108
109 pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
111 self.event_bus.subscribe()
112 }
113
114 pub fn registration_message(&self) -> WorkerToCoordinator {
116 WorkerToCoordinator::Register {
117 worker_id: self.id.clone(),
118 capabilities: self.capabilities.clone(),
119 }
120 }
121
122 pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
132 let start = Instant::now();
133 let _span = tracing::info_span!(
134 "execute_plan",
135 plan_id = %plan.plan_id,
136 n_filters = plan.filters.len(),
137 mode = ?plan.mode,
138 )
139 .entered();
140
141 tracing::info!(
142 "Plan received: {} filters, mode={:?}",
143 plan.filters.len(),
144 plan.mode
145 );
146
147 let all_reqs: Vec<String> = plan
149 .filters
150 .iter()
151 .flat_map(|sf| sf.requirements.iter().cloned())
152 .collect::<std::collections::HashSet<_>>()
153 .into_iter()
154 .collect();
155
156 let python_path = if all_reqs.is_empty() {
158 "python3".to_string()
159 } else {
160 let reqs_str = all_reqs.join("\n");
161 match self.env_manager.ensure_env(&plan.plan_id, &reqs_str) {
162 Ok(path) => {
163 tracing::info!("Using venv for plan {}: {:?}", plan.plan_id, path);
164 path.to_string_lossy().to_string()
165 }
166 Err(e) => {
167 tracing::warn!("Failed to create venv, falling back to system python: {e}");
168 "python3".to_string()
169 }
170 }
171 };
172
173 let filter_specs: Vec<(String, Vec<u8>, bool)> = plan
178 .filters
179 .iter()
180 .map(|sf| (sf.node_id.clone(), sf.pickled_filter.clone(), sf.trainable))
181 .collect();
182
183 if !filter_specs.is_empty() {
184 let filter_names: Vec<&str> =
185 plan.filters.iter().map(|sf| sf.node_id.as_str()).collect();
186 tracing::info!(
187 python = %python_path,
188 filters = ?filter_names,
189 "Spawning Python process for {} filters",
190 filter_specs.len()
191 );
192
193 let mut proc = crate::python_process::PythonProcess::spawn(&python_path, &filter_specs)
194 .map_err(|e| {
195 tracing::error!("Failed to spawn Python process: {e}");
196 e
197 })
198 .expect("PythonProcess spawn failed");
199
200 for sf in &plan.filters {
202 if let Some(state) = &sf.state {
203 let size = match state {
204 Value::Bytes(b) => b.len(),
205 _ => 0,
206 };
207 tracing::info!(
208 node_id = %sf.node_id,
209 size_bytes = size,
210 "Loading trained state from previous epoch"
211 );
212 if let Err(e) = proc.set_state(&sf.node_id, state) {
213 tracing::warn!(
214 node_id = %sf.node_id,
215 error = %e,
216 "Failed to load state (will use fresh weights)"
217 );
218 }
219 }
220 }
221
222 let process = Arc::new(std::sync::Mutex::new(proc));
223
224 for sf in &plan.filters {
225 let filter = Box::new(crate::python_process::SubprocessFilter::new(
226 process.clone(),
227 sf.node_id.clone(),
228 sf.trainable,
229 ));
230 self.filters.register(&sf.node_id, filter);
231 if let Some(state) = &sf.state {
232 self.filters.set_state(&sf.node_id, state.clone());
233 }
234 }
235
236 tracing::info!("Filters registered, Python process ready");
237 }
238
239 let input_value = plan
241 .input
242 .as_ref()
243 .map(|src| src.resolve(self.data_store.as_deref(), &self.temp_store));
244
245 if let Some(InputSource::Reference { data_ref }) = &plan.input
248 && let Some(store) = self.data_store.clone()
249 && let Ok(meta) = store.meta(data_ref)
250 && meta.total_rows > 1024
251 {
252 return self.execute_streamed_from_store(plan, &store, data_ref, &meta, start);
253 }
254
255 let runner = somatize_runtime::LocalRunner;
257 let x = input_value.unwrap_or(Value::Empty);
258
259 let result = match &plan.mode {
260 ExecutionMode::Fit { y } => runner
261 .fit(
262 &plan.plan,
263 &self.filters,
264 self.cache.as_ref(),
265 &self.event_bus,
266 &x,
267 y.as_ref(),
268 )
269 .map(|(output, all_outputs)| {
270 let mut trained_states = std::collections::HashMap::new();
272 for (key, value) in &all_outputs {
273 if let Some(node_id) = key.strip_prefix("__state_") {
274 self.filters.set_state(node_id, value.clone());
275 trained_states.insert(node_id.to_string(), value.clone());
276 }
277 }
278 (output, trained_states)
279 }),
280 ExecutionMode::Forward => runner
281 .forward(
282 &plan.plan,
283 &self.filters,
284 self.cache.as_ref(),
285 &self.event_bus,
286 &x,
287 )
288 .map(|output| (output, std::collections::HashMap::new())),
289 };
290
291 let elapsed = start.elapsed().as_millis() as u64;
292 match result {
293 Ok((output, states)) => {
294 tracing::info!(
295 duration_ms = elapsed,
296 n_states = states.len(),
297 "Plan completed successfully"
298 );
299 PlanResult::Success {
300 output: self.wrap_output(output),
301 duration_ms: elapsed,
302 states,
303 }
304 }
305 Err(e) => {
306 tracing::error!(duration_ms = elapsed, error = %e, "Plan failed");
307 PlanResult::Failed {
308 error: e.to_string(),
309 duration_ms: elapsed,
310 }
311 }
312 }
313 }
314
315 fn execute_streamed_from_store(
318 &mut self,
319 plan: &SerializedPlan,
320 store: &Arc<dyn DataStore>,
321 data_ref: &somatize_core::store::DataRef,
322 meta: &somatize_core::store::StoreMeta,
323 start: Instant,
324 ) -> PlanResult {
325 use somatize_runtime::executors::stream::{FittedFilter, StreamExecutor};
326
327 let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
328 let fitted: Vec<FittedFilter> = node_ids
329 .iter()
330 .filter_map(|id| {
331 let filter = self.filters.get(id)?;
332 let state = self.filters.get_state(id).cloned().unwrap_or(Value::Empty);
333 Some(FittedFilter {
334 name: id.clone(),
335 filter,
336 state,
337 })
338 })
339 .collect();
340
341 let mut executor = StreamExecutor::new(fitted);
342 let chunk_size = 1024;
343 let run_id = format!("worker_stream_{}", plan.plan_id);
344
345 self.event_bus.emit(Event::RunStarted {
346 run_id: run_id.clone(),
347 plan_summary: somatize_core::event::PlanSummary {
348 total_nodes: node_ids.len(),
349 cached_nodes: 0,
350 parallel_branches: 0,
351 },
352 });
353
354 let mut last_output = Value::Empty;
355 let total = meta.total_rows;
356 let mut chunk_idx = 0;
357
358 for row_start in (0..total).step_by(chunk_size) {
359 let len = chunk_size.min(total - row_start);
360 let chunk = match store.get_rows(data_ref, row_start, len) {
361 Ok(c) => c,
362 Err(e) => {
363 return PlanResult::Failed {
364 error: format!("get_rows({row_start}..{}): {e}", row_start + len),
365 duration_ms: start.elapsed().as_millis() as u64,
366 };
367 }
368 };
369
370 match executor.process_chunk(chunk) {
371 Ok(Some(output)) => last_output = output,
372 Ok(None) => {} Err(e) => {
374 return PlanResult::Failed {
375 error: format!("stream chunk {chunk_idx}: {e}"),
376 duration_ms: start.elapsed().as_millis() as u64,
377 };
378 }
379 }
380 chunk_idx += 1;
381 }
382
383 match executor.flush() {
385 Ok(Some(output)) => last_output = output,
386 Ok(None) => {}
387 Err(e) => {
388 return PlanResult::Failed {
389 error: format!("stream flush: {e}"),
390 duration_ms: start.elapsed().as_millis() as u64,
391 };
392 }
393 }
394
395 tracing::info!(
396 "Streamed {chunk_idx} chunks ({total} rows) in {}ms",
397 start.elapsed().as_millis()
398 );
399
400 PlanResult::Success {
401 output: self.wrap_output(last_output),
402 duration_ms: start.elapsed().as_millis() as u64,
403 states: std::collections::HashMap::new(),
404 }
405 }
406
407 pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
409 match target {
410 somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
411 somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
412 }
413 }
414}
415
416#[cfg(test)]
417mod tests {
418 use super::*;
419 use somatize_compiler::ExecutionPlan;
420 use somatize_core::cache::CacheKey;
421 use somatize_core::error::Result as SomaResult;
422 use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
423 use somatize_core::value::Value;
424
425 struct TestDoubler;
426
427 impl Filter for TestDoubler {
428 fn config_hash(&self) -> CacheKey {
429 CacheKey::from_parts(&[b"TestDoubler"])
430 }
431 fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
432 Ok(Value::Empty)
433 }
434 fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
435 match x {
436 Value::Tensor { values, shape } => {
437 let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
438 Ok(Value::tensor(doubled, shape.clone()))
439 }
440 _ => Ok(x.clone()),
441 }
442 }
443 fn meta(&self) -> FilterMeta {
444 FilterMeta {
445 name: "TestDoubler".into(),
446 kind: FilterKind::Stateless,
447 cacheable: true,
448 differentiable: true,
449 stream_mode: StreamMode::FixedState,
450 distribution: somatize_core::filter::Distribution::Local,
451 input_schema: None,
452 output_schema: None,
453 }
454 }
455
456 fn as_any(&self) -> &dyn std::any::Any {
457 self
458 }
459 }
460
461 fn make_worker() -> Worker {
462 Worker::new(
463 "test_worker",
464 Capabilities {
465 cpu_cores: 4,
466 ram_bytes: 8_000_000_000,
467 gpus: vec![],
468 python_envs: vec![],
469 tags: vec!["cpu".into(), "test".into()],
470 },
471 )
472 }
473
474 #[test]
475 fn worker_registration() {
476 let worker = make_worker();
477 let msg = worker.registration_message();
478 if let WorkerToCoordinator::Register {
479 worker_id,
480 capabilities,
481 } = msg
482 {
483 assert_eq!(worker_id, "test_worker");
484 assert_eq!(capabilities.cpu_cores, 4);
485 } else {
486 panic!("wrong message type");
487 }
488 }
489
490 #[test]
491 fn worker_executes_plan_successfully() {
492 let mut worker = make_worker();
493 worker.register_filter("doubler", Box::new(TestDoubler));
494
495 let plan = SerializedPlan {
496 plan_id: "p_001".into(),
497 plan: ExecutionPlan::Execute {
498 node_id: "doubler".into(),
499 },
500 input: Some(crate::protocol::InputSource::Inline {
501 value: Value::tensor(vec![1.0, 2.0, 3.0], vec![3]),
502 }),
503 filters: vec![],
504 mode: ExecutionMode::default(),
505 metadata: serde_json::json!({}),
506 };
507
508 let result = worker.execute_plan(&plan);
509
510 if let PlanResult::Success {
511 output,
512 duration_ms,
513 ..
514 } = result
515 {
516 let value = match output {
517 OutputDelivery::Inline { value } => value,
518 _ => panic!("expected inline output"),
519 };
520 let (data, _) = value.as_tensor().unwrap();
521 assert_eq!(data, &[2.0, 4.0, 6.0]);
522 assert!(duration_ms < 1000);
523 } else {
524 panic!("expected success, got: {result:?}");
525 }
526 }
527
528 #[test]
529 fn worker_handles_missing_filter() {
530 let mut worker = make_worker();
531 let plan = SerializedPlan {
534 plan_id: "p_002".into(),
535 plan: ExecutionPlan::Execute {
536 node_id: "nonexistent".into(),
537 },
538 input: None,
539 filters: vec![],
540 mode: ExecutionMode::default(),
541 metadata: serde_json::json!({}),
542 };
543
544 let result = worker.execute_plan(&plan);
545 assert!(matches!(result, PlanResult::Failed { .. }));
546 }
547
548 #[test]
549 fn worker_matches_target_by_id() {
550 let worker = make_worker();
551 assert!(
552 worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
553 "test_worker".into()
554 ))
555 );
556 assert!(
557 !worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
558 "other".into()
559 ))
560 );
561 }
562
563 #[test]
564 fn worker_matches_target_by_tag() {
565 let worker = make_worker();
566 assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("cpu".into())));
567 assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("test".into())));
568 assert!(!worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("gpu".into())));
569 }
570
571 #[test]
572 fn worker_executes_sequence() {
573 let mut worker = make_worker();
574 worker.register_filter("d1", Box::new(TestDoubler));
575 worker.register_filter("d2", Box::new(TestDoubler));
576
577 let plan = SerializedPlan {
578 plan_id: "p_003".into(),
579 plan: ExecutionPlan::Sequence(vec![
580 ExecutionPlan::Execute {
581 node_id: "d1".into(),
582 },
583 ExecutionPlan::Execute {
584 node_id: "d2".into(),
585 },
586 ]),
587 input: Some(crate::protocol::InputSource::Inline {
588 value: Value::tensor(vec![5.0], vec![1]),
589 }),
590 filters: vec![],
591 mode: ExecutionMode::default(),
592 metadata: serde_json::json!({}),
593 };
594
595 let result = worker.execute_plan(&plan);
596 if let PlanResult::Success { output, .. } = result {
597 let value = match output {
598 OutputDelivery::Inline { value } => value,
599 _ => panic!("expected inline output"),
600 };
601 let (data, _) = value.as_tensor().unwrap();
602 assert_eq!(data, &[20.0]); } else {
604 panic!("expected success");
605 }
606 }
607
608 #[test]
609 fn worker_emits_events() {
610 let mut worker = make_worker();
611 worker.register_filter("doubler", Box::new(TestDoubler));
612 let mut rx = worker.subscribe();
613
614 let plan = SerializedPlan {
615 plan_id: "p_004".into(),
616 plan: ExecutionPlan::Execute {
617 node_id: "doubler".into(),
618 },
619 input: Some(crate::protocol::InputSource::Inline {
620 value: Value::tensor(vec![1.0], vec![1]),
621 }),
622 filters: vec![],
623 mode: ExecutionMode::default(),
624 metadata: serde_json::json!({}),
625 };
626
627 worker.execute_plan(&plan);
628
629 let mut events = Vec::new();
630 while let Ok(e) = rx.try_recv() {
631 events.push(e);
632 }
633 assert!(
634 events
635 .iter()
636 .any(|e| matches!(e, Event::NodeStarted { .. }))
637 );
638 assert!(
639 events
640 .iter()
641 .any(|e| matches!(e, Event::NodeCompleted { .. }))
642 );
643 }
644}