1use crate::env_manager::{EnvManager, EnvType};
7use crate::protocol::*;
8use crate::worker::Worker;
9use axum::Router;
10use axum::extract::ws::{Message, WebSocket};
11use axum::extract::{Query, State, WebSocketUpgrade};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::routing::{get, post};
15use somatize_core::cache::CacheKey;
16use somatize_core::store::{DataStore, LocalDataStore};
17use somatize_core::value::Value;
18use std::collections::HashMap;
19use std::path::PathBuf;
20use std::sync::{Arc, Mutex};
21use std::time::Instant;
22
23struct ServerState {
25 worker: Mutex<Worker>,
26 env_manager: EnvManager,
27 work_dir: PathBuf,
28 token: Option<String>,
30 temp_store: Arc<LocalDataStore>,
32 temp_uploads: Mutex<HashMap<CacheKey, Instant>>,
34 active_streams: Mutex<HashMap<String, somatize_runtime::stream::StreamExecutor>>,
36}
37
38pub fn worker_router(worker: Worker) -> Router {
40 worker_router_full(worker, "/tmp/soma-envs", "/tmp/soma-work", None)
41}
42
43pub fn worker_router_with_dirs(
45 worker: Worker,
46 env_dir: impl Into<PathBuf>,
47 work_dir: impl Into<PathBuf>,
48) -> Router {
49 worker_router_full(worker, env_dir, work_dir, None)
50}
51
52pub fn worker_router_authenticated(
54 worker: Worker,
55 env_dir: impl Into<PathBuf>,
56 work_dir: impl Into<PathBuf>,
57 token: impl Into<String>,
58) -> Router {
59 worker_router_full(worker, env_dir, work_dir, Some(token.into()))
60}
61
62fn worker_router_full(
63 worker: Worker,
64 env_dir: impl Into<PathBuf>,
65 work_dir: impl Into<PathBuf>,
66 token: Option<String>,
67) -> Router {
68 let work = work_dir.into();
69 std::fs::create_dir_all(&work).ok();
70 let temp_store = worker.temp_store().clone();
71 let state = Arc::new(ServerState {
72 worker: Mutex::new(worker),
73 env_manager: EnvManager::new(env_dir, EnvType::Venv),
74 work_dir: work,
75 token,
76 temp_store,
77 temp_uploads: Mutex::new(HashMap::new()),
78 active_streams: Mutex::new(HashMap::new()),
79 });
80 let cleanup_state = state.clone();
82 tokio::spawn(async move {
83 let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
84 loop {
85 interval.tick().await;
86 let cutoff = Instant::now() - std::time::Duration::from_secs(3600);
87 let expired: Vec<CacheKey> = {
88 let uploads = cleanup_state
89 .temp_uploads
90 .lock()
91 .unwrap_or_else(|e| e.into_inner());
92 uploads
93 .iter()
94 .filter(|(_, created)| **created < cutoff)
95 .map(|(k, _)| k.clone())
96 .collect()
97 };
98 if !expired.is_empty() {
99 let mut uploads = cleanup_state
100 .temp_uploads
101 .lock()
102 .unwrap_or_else(|e| e.into_inner());
103 for key in &expired {
104 let data_ref = somatize_core::store::DataRef::Cached {
105 cache_key: key.clone(),
106 };
107 let _ = cleanup_state.temp_store.remove(&data_ref);
108 uploads.remove(key);
109 }
110 tracing::info!("Cleaned up {} expired temp uploads", expired.len());
111 }
112 }
113 });
114
115 Router::new()
116 .route("/health", get(health))
117 .route("/info", get(info))
118 .route("/upload", post(upload_data))
119 .route("/ws", get(ws_handler))
120 .with_state(state)
121}
122
123pub async fn serve_worker(worker: Worker, addr: &str) -> Result<(), Box<dyn std::error::Error>> {
125 let listener = tokio::net::TcpListener::bind(addr).await?;
126 tracing::info!("Worker server listening on {addr}");
127 axum::serve(listener, worker_router(worker)).await?;
128 Ok(())
129}
130
131pub async fn serve_worker_authenticated(
133 worker: Worker,
134 addr: &str,
135 token: &str,
136) -> Result<(), Box<dyn std::error::Error>> {
137 let listener = tokio::net::TcpListener::bind(addr).await?;
138 tracing::info!("Worker server listening on {addr} (authenticated)");
139 let router = worker_router_authenticated(worker, "/tmp/soma-envs", "/tmp/soma-work", token);
140 axum::serve(listener, router).await?;
141 Ok(())
142}
143
144async fn health() -> &'static str {
145 "ok"
146}
147
148async fn info(State(state): State<Arc<ServerState>>) -> impl IntoResponse {
149 let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
150 let msg = worker.registration_message();
151 axum::Json(serde_json::to_value(msg).unwrap_or_default())
152}
153
154async fn upload_data(
159 Query(params): Query<WsParams>,
160 State(state): State<Arc<ServerState>>,
161 body: axum::body::Bytes,
162) -> Result<impl IntoResponse, StatusCode> {
163 if let Some(expected) = &state.token {
165 match ¶ms.token {
166 Some(provided) if provided == expected => {}
167 _ => return Err(StatusCode::UNAUTHORIZED),
168 }
169 }
170
171 let value: Value = rmp_serde::from_slice(&body)
173 .or_else(|_| serde_json::from_slice(&body))
174 .map_err(|_| StatusCode::BAD_REQUEST)?;
175
176 let key = CacheKey::hash_data(&body);
177 let data_ref = state
178 .temp_store
179 .put(&key, &value)
180 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
181
182 state
184 .temp_uploads
185 .lock()
186 .unwrap_or_else(|e| e.into_inner())
187 .insert(key, Instant::now());
188
189 tracing::info!("Uploaded {} bytes → {data_ref:?}", body.len());
190
191 Ok(axum::Json(
192 serde_json::to_value(&data_ref).unwrap_or_default(),
193 ))
194}
195
196#[derive(serde::Deserialize, Default)]
198struct WsParams {
199 token: Option<String>,
200}
201
202async fn ws_handler(
203 ws: WebSocketUpgrade,
204 Query(params): Query<WsParams>,
205 State(state): State<Arc<ServerState>>,
206) -> Result<impl IntoResponse, StatusCode> {
207 if let Some(expected) = &state.token {
209 match ¶ms.token {
210 Some(provided) if provided == expected => {}
211 _ => return Err(StatusCode::UNAUTHORIZED),
212 }
213 }
214 Ok(ws.on_upgrade(move |socket| handle_ws(socket, state)))
215}
216
217async fn handle_ws(mut socket: WebSocket, state: Arc<ServerState>) {
218 loop {
219 match socket.recv().await {
220 Some(Ok(Message::Text(text))) => {
221 let response = match serde_json::from_str::<CoordinatorToWorker>(&text) {
222 Ok(CoordinatorToWorker::AssignPlan { plan }) => {
223 let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
224 let plan_id = plan.plan_id.clone();
225 let worker_id = worker.id.clone();
226 let result = worker.execute_plan(&plan);
227 let msg = WorkerToCoordinator::PlanResult {
228 worker_id,
229 plan_id,
230 result,
231 };
232 serde_json::to_string(&msg).unwrap_or_default()
233 }
234 Ok(CoordinatorToWorker::StatusRequest) => {
235 let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
236 serde_json::to_string(&worker.registration_message()).unwrap_or_default()
237 }
238 Ok(CoordinatorToWorker::CancelPlan { .. }) => {
239 r#"{"status": "cancel_not_implemented"}"#.to_string()
240 }
241 Ok(CoordinatorToWorker::AssignPythonJob { job }) => {
242 let messages = execute_python_job_with_progress(&state, &job);
244 for msg in &messages[..messages.len().saturating_sub(1)] {
246 if socket
247 .send(Message::Text(msg.clone().into()))
248 .await
249 .is_err()
250 {
251 break;
252 }
253 }
254 messages.into_iter().last().unwrap_or_default()
256 }
257 Ok(CoordinatorToWorker::Ping) => r#"{"type":"Pong"}"#.to_string(),
258 Ok(CoordinatorToWorker::Registered { .. }) => continue,
259 Ok(CoordinatorToWorker::Shutdown { reason }) => {
260 tracing::info!("Shutdown requested: {reason}");
261 let _ = socket
262 .send(Message::Text(r#"{"type":"ShutdownAck"}"#.into()))
263 .await;
264 std::process::exit(0);
265 }
266 Err(e) => {
267 format!(r#"{{"error": "invalid message: {e}"}}"#)
268 }
269 };
270
271 if socket.send(Message::Text(response.into())).await.is_err() {
272 break;
273 }
274 }
275 Some(Ok(Message::Binary(bytes))) => {
276 if let Ok(stream_msg) = rmp_serde::from_slice::<StreamMessage>(&bytes) {
277 let reply = handle_stream_message(stream_msg, &state);
278 if let Some(reply_msg) = reply {
279 let reply_bytes = rmp_serde::to_vec(&reply_msg).unwrap_or_default();
280 if socket
281 .send(Message::Binary(reply_bytes.into()))
282 .await
283 .is_err()
284 {
285 break;
286 }
287 }
288 }
289 }
290 Some(Ok(Message::Close(_))) | None => break,
291 _ => {}
292 }
293 }
294}
295
296fn handle_stream_message(msg: StreamMessage, state: &Arc<ServerState>) -> Option<StreamMessage> {
298 use somatize_runtime::stream::{FittedFilter, StreamExecutor};
299
300 match msg {
301 StreamMessage::StreamBegin {
302 stream_id, plan, ..
303 } => {
304 let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
306
307 for sf in &plan.filters {
309 let filter = Box::new(crate::worker::PickledFilterRunner {
310 pickled_bytes: sf.pickled_filter.clone(),
311 node_id: sf.node_id.clone(),
312 python_path: "python3".to_string(),
313 requirements: sf.requirements.clone(),
314 });
315 worker.register_filter(&sf.node_id, filter);
316 if let Some(s) = &sf.state {
317 worker.set_filter_state(&sf.node_id, s.clone());
318 }
319 }
320
321 let node_ids = plan.plan.node_ids();
323 let fitted: Vec<FittedFilter> = node_ids
324 .iter()
325 .filter_map(|id| {
326 let filter = worker.get_filter(id)?;
327 let filter_state = worker.get_filter_state(id);
328 Some(FittedFilter {
329 name: id.to_string(),
330 filter,
331 state: filter_state,
332 })
333 })
334 .collect();
335
336 let executor = StreamExecutor::new(fitted);
337 state
338 .active_streams
339 .lock()
340 .unwrap_or_else(|e| e.into_inner())
341 .insert(stream_id, executor);
342
343 None }
345 StreamMessage::ChunkData {
346 stream_id,
347 chunk_index,
348 value,
349 } => {
350 let mut streams = state
351 .active_streams
352 .lock()
353 .unwrap_or_else(|e| e.into_inner());
354 if let Some(executor) = streams.get_mut(&stream_id) {
355 match executor.process_chunk(value) {
356 Ok(Some(result)) => Some(StreamMessage::ChunkResult {
357 stream_id,
358 chunk_index,
359 value: result,
360 }),
361 Ok(None) => None, Err(e) => Some(StreamMessage::StreamComplete {
363 stream_id,
364 result: PlanResult::Failed {
365 error: e.to_string(),
366 duration_ms: 0,
367 },
368 }),
369 }
370 } else {
371 Some(StreamMessage::StreamComplete {
372 stream_id,
373 result: PlanResult::Failed {
374 error: "unknown stream_id".to_string(),
375 duration_ms: 0,
376 },
377 })
378 }
379 }
380 StreamMessage::StreamEnd { stream_id } => {
381 let mut streams = state
382 .active_streams
383 .lock()
384 .unwrap_or_else(|e| e.into_inner());
385 if let Some(mut executor) = streams.remove(&stream_id) {
386 let output = executor
388 .flush()
389 .unwrap_or(None)
390 .unwrap_or(somatize_core::value::Value::Empty);
391 Some(StreamMessage::StreamComplete {
392 stream_id,
393 result: PlanResult::Success {
394 output,
395 duration_ms: 0,
396 states: std::collections::HashMap::new(),
397 },
398 })
399 } else {
400 None
401 }
402 }
403 _ => None,
404 }
405}
406
407fn execute_python_job_with_progress(state: &ServerState, job: &PythonPipelineJob) -> Vec<String> {
409 let start = Instant::now();
410 let mut messages = Vec::new();
411 let worker_id = {
412 let w = state.worker.lock().unwrap_or_else(|e| e.into_inner());
413 w.id.clone()
414 };
415
416 let progress = |wid: &str, jid: &str, phase: &str, step: u32, total: u32| -> String {
417 serde_json::to_string(&WorkerToCoordinator::JobProgress {
418 worker_id: wid.into(),
419 job_id: jid.into(),
420 phase: phase.into(),
421 step,
422 total,
423 metrics: serde_json::json!({}),
424 })
425 .unwrap_or_default()
426 };
427
428 messages.push(progress(&worker_id, &job.job_id, "environment", 1, 4));
430
431 let python = match state
432 .env_manager
433 .ensure_env(&job.pipeline_id, &job.requirements)
434 {
435 Ok(p) => p,
436 Err(e) => {
437 tracing::error!("Failed to create env for pipeline {}: {e}", job.pipeline_id);
438 let msg = WorkerToCoordinator::JobResult {
439 worker_id,
440 job_id: job.job_id.clone(),
441 success: false,
442 metrics: serde_json::json!({}),
443 output: format!("Environment setup failed: {e}"),
444 duration_ms: start.elapsed().as_millis() as u64,
445 };
446 messages.push(serde_json::to_string(&msg).unwrap_or_default());
447 return messages;
448 }
449 };
450
451 messages.push(progress(&worker_id, &job.job_id, "write_files", 2, 4));
453
454 let job_dir = state.work_dir.join(format!("job-{}", job.job_id));
455 if let Err(e) = std::fs::create_dir_all(&job_dir) {
456 let msg = WorkerToCoordinator::JobResult {
457 worker_id,
458 job_id: job.job_id.clone(),
459 success: false,
460 metrics: serde_json::json!({}),
461 output: format!("Failed to create work dir: {e}"),
462 duration_ms: start.elapsed().as_millis() as u64,
463 };
464 messages.push(serde_json::to_string(&msg).unwrap_or_default());
465 return messages;
466 }
467
468 for file in &job.files {
469 let file_path = job_dir.join(&file.path);
470 if let Some(parent) = file_path.parent() {
471 std::fs::create_dir_all(parent).ok();
472 }
473 if let Err(e) = std::fs::write(&file_path, &file.content) {
474 tracing::error!("Failed to write {}: {e}", file.path);
475 }
476 }
477
478 messages.push(progress(&worker_id, &job.job_id, "execute", 3, 4));
480
481 tracing::info!(
482 "Executing job {} with python: {}",
483 job.job_id,
484 python.display()
485 );
486
487 let output = std::process::Command::new(&python)
488 .arg(&job.entry_point)
489 .current_dir(&job_dir)
490 .env("PYTHONPATH", &job_dir)
491 .output();
492
493 let duration_ms = start.elapsed().as_millis() as u64;
494
495 let _ = std::fs::remove_dir_all(&job_dir);
497 messages.push(progress(&worker_id, &job.job_id, "collect_results", 4, 4));
498
499 let result_msg = match output {
500 Ok(out) => {
501 let stdout = String::from_utf8_lossy(&out.stdout).to_string();
502 let stderr = String::from_utf8_lossy(&out.stderr).to_string();
503 let success = out.status.success();
504
505 let metrics = stdout
506 .lines()
507 .rev()
508 .find_map(|line| serde_json::from_str::<serde_json::Value>(line).ok())
509 .unwrap_or(serde_json::json!({}));
510
511 if !success {
512 tracing::warn!(
513 "Job {} failed: {}",
514 job.job_id,
515 stderr.chars().take(200).collect::<String>()
516 );
517 }
518
519 WorkerToCoordinator::JobResult {
520 worker_id,
521 job_id: job.job_id.clone(),
522 success,
523 metrics,
524 output: if success {
525 stdout
526 } else {
527 format!("STDERR:\n{stderr}\nSTDOUT:\n{stdout}")
528 },
529 duration_ms,
530 }
531 }
532 Err(e) => WorkerToCoordinator::JobResult {
533 worker_id,
534 job_id: job.job_id.clone(),
535 success: false,
536 metrics: serde_json::json!({}),
537 output: format!("Failed to execute: {e}"),
538 duration_ms,
539 },
540 };
541 messages.push(serde_json::to_string(&result_msg).unwrap_or_default());
542 messages
543}
544
545#[cfg(test)]
546mod tests {
547 use super::*;
548 use crate::protocol::Capabilities;
549 fn make_worker() -> Worker {
550 Worker::new(
551 "test_worker",
552 Capabilities {
553 cpu_cores: 4,
554 ram_bytes: 8_000_000_000,
555 gpus: vec![],
556 python_envs: vec![],
557 tags: vec!["test".into()],
558 },
559 )
560 }
561
562 #[tokio::test]
563 async fn router_builds() {
564 let _router = worker_router(make_worker());
565 }
566
567 #[tokio::test]
568 async fn health_returns_ok() {
569 let resp = health().await;
570 assert_eq!(resp, "ok");
571 }
572
573 #[tokio::test]
574 async fn full_server_starts_and_stops() {
575 let worker = make_worker();
576 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
577 let addr = listener.local_addr().unwrap();
578
579 let server = tokio::spawn(async move {
580 axum::serve(listener, worker_router(worker)).await.unwrap();
581 });
582
583 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
584
585 let client = reqwest::Client::new();
586 let resp = client
587 .get(format!("http://{addr}/health"))
588 .send()
589 .await
590 .unwrap();
591 assert_eq!(resp.text().await.unwrap(), "ok");
592
593 let resp = client
594 .get(format!("http://{addr}/info"))
595 .send()
596 .await
597 .unwrap();
598 let json: serde_json::Value = resp.json().await.unwrap();
599 assert!(json.get("type").is_some() || json.get("worker_id").is_some());
600
601 server.abort();
602 }
603}