1use crate::env_manager::{EnvManager, EnvType};
7use crate::protocol::*;
8use crate::worker::Worker;
9use axum::Router;
10use axum::extract::ws::{Message, WebSocket};
11use axum::extract::{Query, State, WebSocketUpgrade};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::routing::get;
15use std::path::PathBuf;
16use std::sync::{Arc, Mutex};
17use std::time::Instant;
18
19struct ServerState {
21 worker: Mutex<Worker>,
22 env_manager: EnvManager,
23 work_dir: PathBuf,
24 token: Option<String>,
26}
27
28pub fn worker_router(worker: Worker) -> Router {
30 worker_router_full(worker, "/tmp/soma-envs", "/tmp/soma-work", None)
31}
32
33pub fn worker_router_with_dirs(
35 worker: Worker,
36 env_dir: impl Into<PathBuf>,
37 work_dir: impl Into<PathBuf>,
38) -> Router {
39 worker_router_full(worker, env_dir, work_dir, None)
40}
41
42pub fn worker_router_authenticated(
44 worker: Worker,
45 env_dir: impl Into<PathBuf>,
46 work_dir: impl Into<PathBuf>,
47 token: impl Into<String>,
48) -> Router {
49 worker_router_full(worker, env_dir, work_dir, Some(token.into()))
50}
51
52fn worker_router_full(
53 worker: Worker,
54 env_dir: impl Into<PathBuf>,
55 work_dir: impl Into<PathBuf>,
56 token: Option<String>,
57) -> Router {
58 let work = work_dir.into();
59 std::fs::create_dir_all(&work).ok();
60 let state = Arc::new(ServerState {
61 worker: Mutex::new(worker),
62 env_manager: EnvManager::new(env_dir, EnvType::Venv),
63 work_dir: work,
64 token,
65 });
66 Router::new()
67 .route("/health", get(health))
68 .route("/info", get(info))
69 .route("/ws", get(ws_handler))
70 .with_state(state)
71}
72
73pub async fn serve_worker(worker: Worker, addr: &str) -> Result<(), Box<dyn std::error::Error>> {
75 let listener = tokio::net::TcpListener::bind(addr).await?;
76 tracing::info!("Worker server listening on {addr}");
77 axum::serve(listener, worker_router(worker)).await?;
78 Ok(())
79}
80
81pub async fn serve_worker_authenticated(
83 worker: Worker,
84 addr: &str,
85 token: &str,
86) -> Result<(), Box<dyn std::error::Error>> {
87 let listener = tokio::net::TcpListener::bind(addr).await?;
88 tracing::info!("Worker server listening on {addr} (authenticated)");
89 let router = worker_router_authenticated(worker, "/tmp/soma-envs", "/tmp/soma-work", token);
90 axum::serve(listener, router).await?;
91 Ok(())
92}
93
/// Liveness probe: always answers with the literal body `ok`.
async fn health() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
97
98async fn info(State(state): State<Arc<ServerState>>) -> impl IntoResponse {
99 let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
100 let msg = worker.registration_message();
101 axum::Json(serde_json::to_value(msg).unwrap_or_default())
102}
103
104#[derive(serde::Deserialize, Default)]
106struct WsParams {
107 token: Option<String>,
108}
109
110async fn ws_handler(
111 ws: WebSocketUpgrade,
112 Query(params): Query<WsParams>,
113 State(state): State<Arc<ServerState>>,
114) -> Result<impl IntoResponse, StatusCode> {
115 if let Some(expected) = &state.token {
117 match ¶ms.token {
118 Some(provided) if provided == expected => {}
119 _ => return Err(StatusCode::UNAUTHORIZED),
120 }
121 }
122 Ok(ws.on_upgrade(move |socket| handle_ws(socket, state)))
123}
124
125async fn handle_ws(mut socket: WebSocket, state: Arc<ServerState>) {
126 loop {
127 match socket.recv().await {
128 Some(Ok(Message::Text(text))) => {
129 let response = match serde_json::from_str::<CoordinatorToWorker>(&text) {
130 Ok(CoordinatorToWorker::AssignPlan { plan }) => {
131 let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
132 let plan_id = plan.plan_id.clone();
133 let worker_id = worker.id.clone();
134 let result = worker.execute_plan(&plan);
135 let msg = WorkerToCoordinator::PlanResult {
136 worker_id,
137 plan_id,
138 result,
139 };
140 serde_json::to_string(&msg).unwrap_or_default()
141 }
142 Ok(CoordinatorToWorker::StatusRequest) => {
143 let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
144 serde_json::to_string(&worker.registration_message()).unwrap_or_default()
145 }
146 Ok(CoordinatorToWorker::CancelPlan { .. }) => {
147 r#"{"status": "cancel_not_implemented"}"#.to_string()
148 }
149 Ok(CoordinatorToWorker::AssignPythonJob { job }) => {
150 let messages = execute_python_job_with_progress(&state, &job);
152 for msg in &messages[..messages.len().saturating_sub(1)] {
154 if socket
155 .send(Message::Text(msg.clone().into()))
156 .await
157 .is_err()
158 {
159 break;
160 }
161 }
162 messages.into_iter().last().unwrap_or_default()
164 }
165 Ok(CoordinatorToWorker::Ping) => r#"{"type":"Pong"}"#.to_string(),
166 Ok(CoordinatorToWorker::Registered { .. }) => continue,
167 Err(e) => {
168 format!(r#"{{"error": "invalid message: {e}"}}"#)
169 }
170 };
171
172 if socket.send(Message::Text(response.into())).await.is_err() {
173 break;
174 }
175 }
176 Some(Ok(Message::Close(_))) | None => break,
177 _ => {}
178 }
179 }
180}
181
182fn execute_python_job_with_progress(state: &ServerState, job: &PythonPipelineJob) -> Vec<String> {
184 let start = Instant::now();
185 let mut messages = Vec::new();
186 let worker_id = {
187 let w = state.worker.lock().unwrap_or_else(|e| e.into_inner());
188 w.id.clone()
189 };
190
191 let progress = |wid: &str, jid: &str, phase: &str, step: u32, total: u32| -> String {
192 serde_json::to_string(&WorkerToCoordinator::JobProgress {
193 worker_id: wid.into(),
194 job_id: jid.into(),
195 phase: phase.into(),
196 step,
197 total,
198 metrics: serde_json::json!({}),
199 })
200 .unwrap_or_default()
201 };
202
203 messages.push(progress(&worker_id, &job.job_id, "environment", 1, 4));
205
206 let python = match state
207 .env_manager
208 .ensure_env(&job.pipeline_id, &job.requirements)
209 {
210 Ok(p) => p,
211 Err(e) => {
212 tracing::error!("Failed to create env for pipeline {}: {e}", job.pipeline_id);
213 let msg = WorkerToCoordinator::JobResult {
214 worker_id,
215 job_id: job.job_id.clone(),
216 success: false,
217 metrics: serde_json::json!({}),
218 output: format!("Environment setup failed: {e}"),
219 duration_ms: start.elapsed().as_millis() as u64,
220 };
221 messages.push(serde_json::to_string(&msg).unwrap_or_default());
222 return messages;
223 }
224 };
225
226 messages.push(progress(&worker_id, &job.job_id, "write_files", 2, 4));
228
229 let job_dir = state.work_dir.join(format!("job-{}", job.job_id));
230 if let Err(e) = std::fs::create_dir_all(&job_dir) {
231 let msg = WorkerToCoordinator::JobResult {
232 worker_id,
233 job_id: job.job_id.clone(),
234 success: false,
235 metrics: serde_json::json!({}),
236 output: format!("Failed to create work dir: {e}"),
237 duration_ms: start.elapsed().as_millis() as u64,
238 };
239 messages.push(serde_json::to_string(&msg).unwrap_or_default());
240 return messages;
241 }
242
243 for file in &job.files {
244 let file_path = job_dir.join(&file.path);
245 if let Some(parent) = file_path.parent() {
246 std::fs::create_dir_all(parent).ok();
247 }
248 if let Err(e) = std::fs::write(&file_path, &file.content) {
249 tracing::error!("Failed to write {}: {e}", file.path);
250 }
251 }
252
253 messages.push(progress(&worker_id, &job.job_id, "execute", 3, 4));
255
256 tracing::info!(
257 "Executing job {} with python: {}",
258 job.job_id,
259 python.display()
260 );
261
262 let output = std::process::Command::new(&python)
263 .arg(&job.entry_point)
264 .current_dir(&job_dir)
265 .env("PYTHONPATH", &job_dir)
266 .output();
267
268 let duration_ms = start.elapsed().as_millis() as u64;
269
270 let _ = std::fs::remove_dir_all(&job_dir);
272 messages.push(progress(&worker_id, &job.job_id, "collect_results", 4, 4));
273
274 let result_msg = match output {
275 Ok(out) => {
276 let stdout = String::from_utf8_lossy(&out.stdout).to_string();
277 let stderr = String::from_utf8_lossy(&out.stderr).to_string();
278 let success = out.status.success();
279
280 let metrics = stdout
281 .lines()
282 .rev()
283 .find_map(|line| serde_json::from_str::<serde_json::Value>(line).ok())
284 .unwrap_or(serde_json::json!({}));
285
286 if !success {
287 tracing::warn!(
288 "Job {} failed: {}",
289 job.job_id,
290 stderr.chars().take(200).collect::<String>()
291 );
292 }
293
294 WorkerToCoordinator::JobResult {
295 worker_id,
296 job_id: job.job_id.clone(),
297 success,
298 metrics,
299 output: if success {
300 stdout
301 } else {
302 format!("STDERR:\n{stderr}\nSTDOUT:\n{stdout}")
303 },
304 duration_ms,
305 }
306 }
307 Err(e) => WorkerToCoordinator::JobResult {
308 worker_id,
309 job_id: job.job_id.clone(),
310 success: false,
311 metrics: serde_json::json!({}),
312 output: format!("Failed to execute: {e}"),
313 duration_ms,
314 },
315 };
316 messages.push(serde_json::to_string(&result_msg).unwrap_or_default());
317 messages
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::Capabilities;

    /// Builds a small CPU-only worker used by the tests in this module.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// Constructing the default router must not panic.
    #[test]
    fn router_builds() {
        let worker = make_worker();
        let _router = worker_router(worker);
    }

    /// The health handler answers with the literal `ok`.
    #[tokio::test]
    async fn health_returns_ok() {
        assert_eq!(health().await, "ok");
    }

    /// Serves the router on an ephemeral port and exercises `/health` and `/info`.
    #[tokio::test]
    async fn full_server_starts_and_stops() {
        let worker = make_worker();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let app = worker_router(worker);
            axum::serve(listener, app).await.unwrap();
        });

        // Give the spawned server task a moment to start accepting.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let client = reqwest::Client::new();

        let health_body = client
            .get(format!("http://{addr}/health"))
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap();
        assert_eq!(health_body, "ok");

        let info_body: serde_json::Value = client
            .get(format!("http://{addr}/info"))
            .send()
            .await
            .unwrap()
            .json()
            .await
            .unwrap();
        let has_expected_key = ["type", "worker_id"]
            .iter()
            .any(|key| info_body.get(*key).is_some());
        assert!(has_expected_key);

        server.abort();
    }
}