1use crate::env_manager::{EnvManager, EnvType};
7use crate::protocol::*;
8use crate::worker::Worker;
9use axum::Router;
10use axum::extract::DefaultBodyLimit;
11use axum::extract::ws::{Message, WebSocket};
12use axum::extract::{Query, State, WebSocketUpgrade};
13use axum::http::StatusCode;
14use axum::response::IntoResponse;
15use axum::routing::{get, post};
16use somatize_core::cache::CacheKey;
17use somatize_core::store::{DataStore, LocalDataStore};
18use somatize_core::value::Value;
19use std::collections::HashMap;
20use std::path::PathBuf;
21use std::sync::{Arc, Mutex};
22use std::time::Instant;
23
24struct ServerState {
26 worker: Mutex<Worker>,
27 env_manager: EnvManager,
28 work_dir: PathBuf,
29 token: Option<String>,
31 temp_store: Arc<LocalDataStore>,
33 temp_uploads: Mutex<HashMap<CacheKey, Instant>>,
35 active_streams:
37 Mutex<HashMap<String, (somatize_runtime::executors::stream::StreamExecutor, Instant)>>,
38}
39
40pub fn worker_router(worker: Worker) -> Router {
42 worker_router_full(worker, "/tmp/soma-envs", "/tmp/soma-work", None)
43}
44
45pub fn worker_router_with_dirs(
47 worker: Worker,
48 env_dir: impl Into<PathBuf>,
49 work_dir: impl Into<PathBuf>,
50) -> Router {
51 worker_router_full(worker, env_dir, work_dir, None)
52}
53
54pub fn worker_router_authenticated(
56 worker: Worker,
57 env_dir: impl Into<PathBuf>,
58 work_dir: impl Into<PathBuf>,
59 token: impl Into<String>,
60) -> Router {
61 worker_router_full(worker, env_dir, work_dir, Some(token.into()))
62}
63
64fn worker_router_full(
65 worker: Worker,
66 env_dir: impl Into<PathBuf>,
67 work_dir: impl Into<PathBuf>,
68 token: Option<String>,
69) -> Router {
70 let work = work_dir.into();
71 std::fs::create_dir_all(&work).ok();
72 let temp_store = worker.temp_store().clone();
73 let state = Arc::new(ServerState {
74 worker: Mutex::new(worker),
75 env_manager: EnvManager::new(env_dir, EnvType::Venv),
76 work_dir: work,
77 token,
78 temp_store,
79 temp_uploads: Mutex::new(HashMap::new()),
80 active_streams: Mutex::new(HashMap::new()),
81 });
82 let cleanup_state = state.clone();
84 tokio::spawn(async move {
85 let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
86 loop {
87 interval.tick().await;
88 let cutoff = Instant::now() - std::time::Duration::from_secs(3600);
89 let expired: Vec<CacheKey> = {
90 let uploads = cleanup_state
91 .temp_uploads
92 .lock()
93 .unwrap_or_else(|e| e.into_inner());
94 uploads
95 .iter()
96 .filter(|(_, created)| **created < cutoff)
97 .map(|(k, _)| k.clone())
98 .collect()
99 };
100 if !expired.is_empty() {
101 let mut uploads = cleanup_state
102 .temp_uploads
103 .lock()
104 .unwrap_or_else(|e| e.into_inner());
105 for key in &expired {
106 let data_ref = somatize_core::store::DataRef::Cached {
107 cache_key: key.clone(),
108 };
109 let _ = cleanup_state.temp_store.remove(&data_ref);
110 uploads.remove(key);
111 }
112 tracing::info!("Cleaned up {} expired temp uploads", expired.len());
113 }
114 }
115 });
116
117 Router::new()
118 .route("/health", get(health))
119 .route("/info", get(info))
120 .route("/upload", post(upload_data))
121 .route("/download", get(download_data))
122 .route("/ws", get(ws_handler))
123 .layer(DefaultBodyLimit::disable()) .with_state(state)
125}
126
127pub async fn serve_worker(worker: Worker, addr: &str) -> Result<(), Box<dyn std::error::Error>> {
129 let listener = tokio::net::TcpListener::bind(addr).await?;
130 tracing::info!("Worker server listening on {addr}");
131 axum::serve(listener, worker_router(worker)).await?;
132 Ok(())
133}
134
135pub async fn serve_worker_authenticated(
137 worker: Worker,
138 addr: &str,
139 token: &str,
140) -> Result<(), Box<dyn std::error::Error>> {
141 let listener = tokio::net::TcpListener::bind(addr).await?;
142 tracing::info!("Worker server listening on {addr} (authenticated)");
143 let router = worker_router_authenticated(worker, "/tmp/soma-envs", "/tmp/soma-work", token);
144 axum::serve(listener, router).await?;
145 Ok(())
146}
147
/// Liveness probe: always answers with the plain-text body `ok`.
async fn health() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
151
152async fn info(State(state): State<Arc<ServerState>>) -> impl IntoResponse {
153 let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
154 let msg = worker.registration_message();
155 axum::Json(serde_json::to_value(msg).unwrap_or_default())
156}
157
158async fn upload_data(
163 Query(params): Query<WsParams>,
164 State(state): State<Arc<ServerState>>,
165 body: axum::body::Bytes,
166) -> Result<impl IntoResponse, StatusCode> {
167 if let Some(expected) = &state.token {
169 match ¶ms.token {
170 Some(provided) if provided == expected => {}
171 _ => return Err(StatusCode::UNAUTHORIZED),
172 }
173 }
174
175 let value: Value = rmp_serde::from_slice(&body)
177 .or_else(|_| serde_json::from_slice(&body))
178 .map_err(|_| StatusCode::BAD_REQUEST)?;
179
180 let key = CacheKey::hash_data(&body);
181 let data_ref = state
182 .temp_store
183 .put(&key, &value)
184 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
185
186 state
188 .temp_uploads
189 .lock()
190 .unwrap_or_else(|e| e.into_inner())
191 .insert(key, Instant::now());
192
193 tracing::info!("Uploaded {} bytes → {data_ref:?}", body.len());
194
195 Ok(axum::Json(
196 serde_json::to_value(&data_ref).unwrap_or_default(),
197 ))
198}
199
200#[derive(serde::Deserialize)]
202struct DownloadParams {
203 #[serde(rename = "ref")]
205 data_ref: String,
206 token: Option<String>,
207}
208
209async fn download_data(
214 Query(params): Query<DownloadParams>,
215 State(state): State<Arc<ServerState>>,
216) -> Result<impl IntoResponse, StatusCode> {
217 if let Some(expected) = &state.token {
219 match ¶ms.token {
220 Some(provided) if provided == expected => {}
221 _ => return Err(StatusCode::UNAUTHORIZED),
222 }
223 }
224
225 let data_ref: somatize_core::store::DataRef =
226 serde_json::from_str(¶ms.data_ref).map_err(|_| StatusCode::BAD_REQUEST)?;
227
228 let value = state
229 .temp_store
230 .get(&data_ref)
231 .map_err(|_| StatusCode::NOT_FOUND)?;
232
233 let bytes = serde_json::to_vec(&value).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
234
235 Ok((
236 [(axum::http::header::CONTENT_TYPE, "application/json")],
237 bytes,
238 ))
239}
240
241#[derive(serde::Deserialize, Default)]
243struct WsParams {
244 token: Option<String>,
245}
246
247async fn ws_handler(
248 ws: WebSocketUpgrade,
249 Query(params): Query<WsParams>,
250 State(state): State<Arc<ServerState>>,
251) -> Result<impl IntoResponse, StatusCode> {
252 if let Some(expected) = &state.token {
254 match ¶ms.token {
255 Some(provided) if provided == expected => {}
256 _ => return Err(StatusCode::UNAUTHORIZED),
257 }
258 }
259 Ok(ws
260 .max_message_size(usize::MAX) .max_frame_size(usize::MAX)
262 .on_upgrade(move |socket| handle_ws(socket, state)))
263}
264
265async fn handle_ws(mut socket: WebSocket, state: Arc<ServerState>) {
266 loop {
267 match socket.recv().await {
268 Some(Ok(Message::Text(text))) => {
269 let response = match serde_json::from_str::<CoordinatorToWorker>(&text) {
270 Ok(CoordinatorToWorker::AssignPlan { plan }) => {
271 let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
272 let plan_id = plan.plan_id.clone();
273 let worker_id = worker.id.clone();
274 let result = worker.execute_plan(&plan);
275 let msg = WorkerToCoordinator::PlanResult {
276 worker_id,
277 plan_id,
278 result,
279 };
280 serde_json::to_string(&msg).unwrap_or_default()
281 }
282 Ok(CoordinatorToWorker::StatusRequest) => {
283 let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
284 serde_json::to_string(&worker.registration_message()).unwrap_or_default()
285 }
286 Ok(CoordinatorToWorker::CancelPlan { .. }) => {
287 r#"{"status": "cancel_not_implemented"}"#.to_string()
288 }
289 Ok(CoordinatorToWorker::AssignPythonJob { job }) => {
290 let messages = execute_python_job_with_progress(&state, &job);
292 for msg in &messages[..messages.len().saturating_sub(1)] {
294 if socket
295 .send(Message::Text(msg.clone().into()))
296 .await
297 .is_err()
298 {
299 break;
300 }
301 }
302 messages.into_iter().last().unwrap_or_default()
304 }
305 Ok(CoordinatorToWorker::Ping) => r#"{"type":"Pong"}"#.to_string(),
306 Ok(CoordinatorToWorker::Registered { .. }) => continue,
307 Ok(CoordinatorToWorker::Shutdown { reason }) => {
308 tracing::info!("Shutdown requested: {reason}");
309 let _ = socket
310 .send(Message::Text(r#"{"type":"ShutdownAck"}"#.into()))
311 .await;
312 std::process::exit(0);
313 }
314 Ok(
315 CoordinatorToWorker::GetState { .. }
316 | CoordinatorToWorker::SetState { .. }
317 | CoordinatorToWorker::GetGradients { .. }
318 | CoordinatorToWorker::ApplyGradients { .. },
319 ) => {
320 r#"{"error":"get_state/set_state/get_gradients/apply_gradients not yet implemented for SubprocessFilter"}"#.to_string()
321 }
322 Err(e) => {
323 format!(r#"{{"error": "invalid message: {e}"}}"#)
324 }
325 };
326
327 if socket.send(Message::Text(response.into())).await.is_err() {
328 break;
329 }
330 }
331 Some(Ok(Message::Binary(bytes))) => {
332 if let Ok(stream_msg) = rmp_serde::from_slice::<StreamMessage>(&bytes) {
333 let reply = handle_stream_message(stream_msg, &state);
334 if let Some(reply_msg) = reply {
335 let reply_bytes = rmp_serde::to_vec(&reply_msg).unwrap_or_default();
336 if socket
337 .send(Message::Binary(reply_bytes.into()))
338 .await
339 .is_err()
340 {
341 break;
342 }
343 }
344 }
345 }
346 Some(Ok(Message::Close(_))) | None => break,
347 _ => {}
348 }
349 }
350}
351
352fn handle_stream_message(msg: StreamMessage, state: &Arc<ServerState>) -> Option<StreamMessage> {
354 use somatize_runtime::executors::stream::{FittedFilter, StreamExecutor};
355
356 match msg {
357 StreamMessage::StreamBegin {
358 stream_id, plan, ..
359 } => {
360 let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
362
363 let filter_specs: Vec<(String, Vec<u8>, bool)> = plan
365 .filters
366 .iter()
367 .map(|sf| (sf.node_id.clone(), sf.pickled_filter.clone(), sf.trainable))
368 .collect();
369
370 if !filter_specs.is_empty() {
371 let process = Arc::new(std::sync::Mutex::new(
372 crate::python_process::PythonProcess::spawn("python3", &filter_specs)
373 .expect("PythonProcess spawn failed"),
374 ));
375
376 for sf in &plan.filters {
377 let filter: Box<dyn somatize_core::filter::Filter> =
378 Box::new(crate::python_process::SubprocessFilter::new(
379 process.clone(),
380 sf.node_id.clone(),
381 sf.trainable,
382 ));
383 worker.register_filter(&sf.node_id, filter);
384 if let Some(s) = &sf.state {
385 worker.set_filter_state(&sf.node_id, s.clone());
386 }
387 }
388 }
389
390 let node_ids = plan.plan.node_ids();
392 let fitted: Vec<FittedFilter> = node_ids
393 .iter()
394 .filter_map(|id| {
395 let filter = worker.get_filter(id)?;
396 let filter_state = worker.get_filter_state(id);
397 Some(FittedFilter {
398 name: id.to_string(),
399 filter,
400 state: filter_state,
401 })
402 })
403 .collect();
404
405 let executor = StreamExecutor::new(fitted);
406 state
407 .active_streams
408 .lock()
409 .unwrap_or_else(|e| e.into_inner())
410 .insert(stream_id, (executor, Instant::now()));
411
412 None }
414 StreamMessage::ChunkData {
415 stream_id,
416 chunk_index,
417 value,
418 } => {
419 let mut streams = state
420 .active_streams
421 .lock()
422 .unwrap_or_else(|e| e.into_inner());
423 if let Some((executor, _)) = streams.get_mut(&stream_id) {
424 match executor.process_chunk(value) {
425 Ok(Some(result)) => Some(StreamMessage::ChunkResult {
426 stream_id,
427 chunk_index,
428 value: result,
429 }),
430 Ok(None) => None, Err(e) => Some(StreamMessage::StreamComplete {
432 stream_id,
433 result: PlanResult::Failed {
434 error: e.to_string(),
435 duration_ms: 0,
436 },
437 }),
438 }
439 } else {
440 Some(StreamMessage::StreamComplete {
441 stream_id,
442 result: PlanResult::Failed {
443 error: "unknown stream_id".to_string(),
444 duration_ms: 0,
445 },
446 })
447 }
448 }
449 StreamMessage::StreamEnd { stream_id } => {
450 let mut streams = state
451 .active_streams
452 .lock()
453 .unwrap_or_else(|e| e.into_inner());
454 if let Some((mut executor, start)) = streams.remove(&stream_id) {
455 let duration_ms = start.elapsed().as_millis() as u64;
456 let output = executor
458 .flush()
459 .unwrap_or(None)
460 .unwrap_or(somatize_core::value::Value::Empty);
461 Some(StreamMessage::StreamComplete {
462 stream_id,
463 result: PlanResult::Success {
464 output: OutputDelivery::Inline { value: output },
465 duration_ms,
466 states: std::collections::HashMap::new(),
467 },
468 })
469 } else {
470 None
471 }
472 }
473 _ => None,
474 }
475}
476
477fn execute_python_job_with_progress(state: &ServerState, job: &PythonPipelineJob) -> Vec<String> {
479 let start = Instant::now();
480 let mut messages = Vec::new();
481 let worker_id = {
482 let w = state.worker.lock().unwrap_or_else(|e| e.into_inner());
483 w.id.clone()
484 };
485
486 let progress = |wid: &str, jid: &str, phase: &str, step: u32, total: u32| -> String {
487 serde_json::to_string(&WorkerToCoordinator::JobProgress {
488 worker_id: wid.into(),
489 job_id: jid.into(),
490 phase: phase.into(),
491 step,
492 total,
493 metrics: serde_json::json!({}),
494 })
495 .unwrap_or_default()
496 };
497
498 messages.push(progress(&worker_id, &job.job_id, "environment", 1, 4));
500
501 let python = match state
502 .env_manager
503 .ensure_env(&job.pipeline_id, &job.requirements)
504 {
505 Ok(p) => p,
506 Err(e) => {
507 tracing::error!("Failed to create env for pipeline {}: {e}", job.pipeline_id);
508 let msg = WorkerToCoordinator::JobResult {
509 worker_id,
510 job_id: job.job_id.clone(),
511 success: false,
512 metrics: serde_json::json!({}),
513 output: format!("Environment setup failed: {e}"),
514 duration_ms: start.elapsed().as_millis() as u64,
515 };
516 messages.push(serde_json::to_string(&msg).unwrap_or_default());
517 return messages;
518 }
519 };
520
521 messages.push(progress(&worker_id, &job.job_id, "write_files", 2, 4));
523
524 let job_dir = state.work_dir.join(format!("job-{}", job.job_id));
525 if let Err(e) = std::fs::create_dir_all(&job_dir) {
526 let msg = WorkerToCoordinator::JobResult {
527 worker_id,
528 job_id: job.job_id.clone(),
529 success: false,
530 metrics: serde_json::json!({}),
531 output: format!("Failed to create work dir: {e}"),
532 duration_ms: start.elapsed().as_millis() as u64,
533 };
534 messages.push(serde_json::to_string(&msg).unwrap_or_default());
535 return messages;
536 }
537
538 for file in &job.files {
539 let file_path = job_dir.join(&file.path);
540 if let Some(parent) = file_path.parent() {
541 std::fs::create_dir_all(parent).ok();
542 }
543 if let Err(e) = std::fs::write(&file_path, &file.content) {
544 tracing::error!("Failed to write {}: {e}", file.path);
545 }
546 }
547
548 messages.push(progress(&worker_id, &job.job_id, "execute", 3, 4));
550
551 tracing::info!(
552 "Executing job {} with python: {}",
553 job.job_id,
554 python.display()
555 );
556
557 let output = std::process::Command::new(&python)
558 .arg(&job.entry_point)
559 .current_dir(&job_dir)
560 .env("PYTHONPATH", &job_dir)
561 .output();
562
563 let duration_ms = start.elapsed().as_millis() as u64;
564
565 let _ = std::fs::remove_dir_all(&job_dir);
567 messages.push(progress(&worker_id, &job.job_id, "collect_results", 4, 4));
568
569 let result_msg = match output {
570 Ok(out) => {
571 let stdout = String::from_utf8_lossy(&out.stdout).to_string();
572 let stderr = String::from_utf8_lossy(&out.stderr).to_string();
573 let success = out.status.success();
574
575 let metrics = stdout
576 .lines()
577 .rev()
578 .find_map(|line| serde_json::from_str::<serde_json::Value>(line).ok())
579 .unwrap_or(serde_json::json!({}));
580
581 if !success {
582 tracing::warn!(
583 "Job {} failed: {}",
584 job.job_id,
585 stderr.chars().take(200).collect::<String>()
586 );
587 }
588
589 WorkerToCoordinator::JobResult {
590 worker_id,
591 job_id: job.job_id.clone(),
592 success,
593 metrics,
594 output: if success {
595 stdout
596 } else {
597 format!("STDERR:\n{stderr}\nSTDOUT:\n{stdout}")
598 },
599 duration_ms,
600 }
601 }
602 Err(e) => WorkerToCoordinator::JobResult {
603 worker_id,
604 job_id: job.job_id.clone(),
605 success: false,
606 metrics: serde_json::json!({}),
607 output: format!("Failed to execute: {e}"),
608 duration_ms,
609 },
610 };
611 messages.push(serde_json::to_string(&result_msg).unwrap_or_default());
612 messages
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::Capabilities;

    /// A small CPU-only worker shared by the tests below.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    #[tokio::test]
    async fn router_builds() {
        let _router = worker_router(make_worker());
    }

    #[tokio::test]
    async fn health_returns_ok() {
        assert_eq!(health().await, "ok");
    }

    #[tokio::test]
    async fn full_server_starts_and_stops() {
        let worker = make_worker();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            axum::serve(listener, worker_router(worker)).await.unwrap();
        });

        // Give the server task a moment to start.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let client = reqwest::Client::new();
        let base = format!("http://{addr}");

        let health_body = client
            .get(format!("{base}/health"))
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap();
        assert_eq!(health_body, "ok");

        let info_json: serde_json::Value = client
            .get(format!("{base}/info"))
            .send()
            .await
            .unwrap()
            .json()
            .await
            .unwrap();
        assert!(info_json.get("type").is_some() || info_json.get("worker_id").is_some());

        server.abort();
    }
}
673}