Skip to main content

somatize_worker/
server.rs

//! Axum HTTP/WebSocket server for the worker process.
//!
//! Supports optional token authentication on WebSocket connections: when a
//! token is configured, clients must supply it as the `token` query parameter
//! of the `/ws` request. Set a token via [`worker_router_authenticated`] or
//! the `--token` CLI flag.
5
6use crate::env_manager::{EnvManager, EnvType};
7use crate::protocol::*;
8use crate::worker::Worker;
9use axum::Router;
10use axum::extract::ws::{Message, WebSocket};
11use axum::extract::{Query, State, WebSocketUpgrade};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::routing::get;
15use std::path::PathBuf;
16use std::sync::{Arc, Mutex};
17use std::time::Instant;
18
19/// Shared state for the worker HTTP/WebSocket server.
20struct ServerState {
21    worker: Mutex<Worker>,
22    env_manager: EnvManager,
23    work_dir: PathBuf,
24    /// Optional bearer token for authentication.
25    token: Option<String>,
26}
27
28/// Build a worker server router (no authentication).
29pub fn worker_router(worker: Worker) -> Router {
30    worker_router_full(worker, "/tmp/soma-envs", "/tmp/soma-work", None)
31}
32
33/// Build a worker server router with custom directories.
34pub fn worker_router_with_dirs(
35    worker: Worker,
36    env_dir: impl Into<PathBuf>,
37    work_dir: impl Into<PathBuf>,
38) -> Router {
39    worker_router_full(worker, env_dir, work_dir, None)
40}
41
42/// Build a worker server router with authentication.
43pub fn worker_router_authenticated(
44    worker: Worker,
45    env_dir: impl Into<PathBuf>,
46    work_dir: impl Into<PathBuf>,
47    token: impl Into<String>,
48) -> Router {
49    worker_router_full(worker, env_dir, work_dir, Some(token.into()))
50}
51
52fn worker_router_full(
53    worker: Worker,
54    env_dir: impl Into<PathBuf>,
55    work_dir: impl Into<PathBuf>,
56    token: Option<String>,
57) -> Router {
58    let work = work_dir.into();
59    std::fs::create_dir_all(&work).ok();
60    let state = Arc::new(ServerState {
61        worker: Mutex::new(worker),
62        env_manager: EnvManager::new(env_dir, EnvType::Venv),
63        work_dir: work,
64        token,
65    });
66    Router::new()
67        .route("/health", get(health))
68        .route("/info", get(info))
69        .route("/ws", get(ws_handler))
70        .with_state(state)
71}
72
73/// Start a worker server on the given address.
74pub async fn serve_worker(worker: Worker, addr: &str) -> Result<(), Box<dyn std::error::Error>> {
75    let listener = tokio::net::TcpListener::bind(addr).await?;
76    tracing::info!("Worker server listening on {addr}");
77    axum::serve(listener, worker_router(worker)).await?;
78    Ok(())
79}
80
81/// Start a worker server with authentication.
82pub async fn serve_worker_authenticated(
83    worker: Worker,
84    addr: &str,
85    token: &str,
86) -> Result<(), Box<dyn std::error::Error>> {
87    let listener = tokio::net::TcpListener::bind(addr).await?;
88    tracing::info!("Worker server listening on {addr} (authenticated)");
89    let router = worker_router_authenticated(worker, "/tmp/soma-envs", "/tmp/soma-work", token);
90    axum::serve(listener, router).await?;
91    Ok(())
92}
93
/// Liveness probe: always answers with the literal body `ok`.
async fn health() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
97
98async fn info(State(state): State<Arc<ServerState>>) -> impl IntoResponse {
99    let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
100    let msg = worker.registration_message();
101    axum::Json(serde_json::to_value(msg).unwrap_or_default())
102}
103
104/// Query params for WebSocket authentication.
105#[derive(serde::Deserialize, Default)]
106struct WsParams {
107    token: Option<String>,
108}
109
110async fn ws_handler(
111    ws: WebSocketUpgrade,
112    Query(params): Query<WsParams>,
113    State(state): State<Arc<ServerState>>,
114) -> Result<impl IntoResponse, StatusCode> {
115    // Validate token if server requires one
116    if let Some(expected) = &state.token {
117        match &params.token {
118            Some(provided) if provided == expected => {}
119            _ => return Err(StatusCode::UNAUTHORIZED),
120        }
121    }
122    Ok(ws.on_upgrade(move |socket| handle_ws(socket, state)))
123}
124
125async fn handle_ws(mut socket: WebSocket, state: Arc<ServerState>) {
126    loop {
127        match socket.recv().await {
128            Some(Ok(Message::Text(text))) => {
129                let response = match serde_json::from_str::<CoordinatorToWorker>(&text) {
130                    Ok(CoordinatorToWorker::AssignPlan { plan }) => {
131                        let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
132                        let plan_id = plan.plan_id.clone();
133                        let worker_id = worker.id.clone();
134                        let result = worker.execute_plan(&plan);
135                        let msg = WorkerToCoordinator::PlanResult {
136                            worker_id,
137                            plan_id,
138                            result,
139                        };
140                        serde_json::to_string(&msg).unwrap_or_default()
141                    }
142                    Ok(CoordinatorToWorker::StatusRequest) => {
143                        let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
144                        serde_json::to_string(&worker.registration_message()).unwrap_or_default()
145                    }
146                    Ok(CoordinatorToWorker::CancelPlan { .. }) => {
147                        r#"{"status": "cancel_not_implemented"}"#.to_string()
148                    }
149                    Ok(CoordinatorToWorker::AssignPythonJob { job }) => {
150                        // Send progress messages during execution
151                        let messages = execute_python_job_with_progress(&state, &job);
152                        // Send all but the last as intermediate messages
153                        for msg in &messages[..messages.len().saturating_sub(1)] {
154                            if socket
155                                .send(Message::Text(msg.clone().into()))
156                                .await
157                                .is_err()
158                            {
159                                break;
160                            }
161                        }
162                        // Return the last message (result) through the normal path
163                        messages.into_iter().last().unwrap_or_default()
164                    }
165                    Ok(CoordinatorToWorker::Ping) => r#"{"type":"Pong"}"#.to_string(),
166                    Ok(CoordinatorToWorker::Registered { .. }) => continue,
167                    Err(e) => {
168                        format!(r#"{{"error": "invalid message: {e}"}}"#)
169                    }
170                };
171
172                if socket.send(Message::Text(response.into())).await.is_err() {
173                    break;
174                }
175            }
176            Some(Ok(Message::Close(_))) | None => break,
177            _ => {}
178        }
179    }
180}
181
182/// Execute a Python pipeline job with progress reporting.
183fn execute_python_job_with_progress(state: &ServerState, job: &PythonPipelineJob) -> Vec<String> {
184    let start = Instant::now();
185    let mut messages = Vec::new();
186    let worker_id = {
187        let w = state.worker.lock().unwrap_or_else(|e| e.into_inner());
188        w.id.clone()
189    };
190
191    let progress = |wid: &str, jid: &str, phase: &str, step: u32, total: u32| -> String {
192        serde_json::to_string(&WorkerToCoordinator::JobProgress {
193            worker_id: wid.into(),
194            job_id: jid.into(),
195            phase: phase.into(),
196            step,
197            total,
198            metrics: serde_json::json!({}),
199        })
200        .unwrap_or_default()
201    };
202
203    // Phase 1/4: Environment setup
204    messages.push(progress(&worker_id, &job.job_id, "environment", 1, 4));
205
206    let python = match state
207        .env_manager
208        .ensure_env(&job.pipeline_id, &job.requirements)
209    {
210        Ok(p) => p,
211        Err(e) => {
212            tracing::error!("Failed to create env for pipeline {}: {e}", job.pipeline_id);
213            let msg = WorkerToCoordinator::JobResult {
214                worker_id,
215                job_id: job.job_id.clone(),
216                success: false,
217                metrics: serde_json::json!({}),
218                output: format!("Environment setup failed: {e}"),
219                duration_ms: start.elapsed().as_millis() as u64,
220            };
221            messages.push(serde_json::to_string(&msg).unwrap_or_default());
222            return messages;
223        }
224    };
225
226    // Phase 2/4: Write files
227    messages.push(progress(&worker_id, &job.job_id, "write_files", 2, 4));
228
229    let job_dir = state.work_dir.join(format!("job-{}", job.job_id));
230    if let Err(e) = std::fs::create_dir_all(&job_dir) {
231        let msg = WorkerToCoordinator::JobResult {
232            worker_id,
233            job_id: job.job_id.clone(),
234            success: false,
235            metrics: serde_json::json!({}),
236            output: format!("Failed to create work dir: {e}"),
237            duration_ms: start.elapsed().as_millis() as u64,
238        };
239        messages.push(serde_json::to_string(&msg).unwrap_or_default());
240        return messages;
241    }
242
243    for file in &job.files {
244        let file_path = job_dir.join(&file.path);
245        if let Some(parent) = file_path.parent() {
246            std::fs::create_dir_all(parent).ok();
247        }
248        if let Err(e) = std::fs::write(&file_path, &file.content) {
249            tracing::error!("Failed to write {}: {e}", file.path);
250        }
251    }
252
253    // Phase 3/4: Execute
254    messages.push(progress(&worker_id, &job.job_id, "execute", 3, 4));
255
256    tracing::info!(
257        "Executing job {} with python: {}",
258        job.job_id,
259        python.display()
260    );
261
262    let output = std::process::Command::new(&python)
263        .arg(&job.entry_point)
264        .current_dir(&job_dir)
265        .env("PYTHONPATH", &job_dir)
266        .output();
267
268    let duration_ms = start.elapsed().as_millis() as u64;
269
270    // Phase 4/4: Collect results
271    let _ = std::fs::remove_dir_all(&job_dir);
272    messages.push(progress(&worker_id, &job.job_id, "collect_results", 4, 4));
273
274    let result_msg = match output {
275        Ok(out) => {
276            let stdout = String::from_utf8_lossy(&out.stdout).to_string();
277            let stderr = String::from_utf8_lossy(&out.stderr).to_string();
278            let success = out.status.success();
279
280            let metrics = stdout
281                .lines()
282                .rev()
283                .find_map(|line| serde_json::from_str::<serde_json::Value>(line).ok())
284                .unwrap_or(serde_json::json!({}));
285
286            if !success {
287                tracing::warn!(
288                    "Job {} failed: {}",
289                    job.job_id,
290                    stderr.chars().take(200).collect::<String>()
291                );
292            }
293
294            WorkerToCoordinator::JobResult {
295                worker_id,
296                job_id: job.job_id.clone(),
297                success,
298                metrics,
299                output: if success {
300                    stdout
301                } else {
302                    format!("STDERR:\n{stderr}\nSTDOUT:\n{stdout}")
303                },
304                duration_ms,
305            }
306        }
307        Err(e) => WorkerToCoordinator::JobResult {
308            worker_id,
309            job_id: job.job_id.clone(),
310            success: false,
311            metrics: serde_json::json!({}),
312            output: format!("Failed to execute: {e}"),
313            duration_ms,
314        },
315    };
316    messages.push(serde_json::to_string(&result_msg).unwrap_or_default());
317    messages
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::Capabilities;

    /// Build a small worker without GPUs or Python environments for the tests.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// Constructing the default router must not panic.
    #[test]
    fn router_builds() {
        let worker = make_worker();
        let _router = worker_router(worker);
    }

    /// The health handler answers with the literal `ok`.
    #[tokio::test]
    async fn health_returns_ok() {
        assert_eq!(health().await, "ok");
    }

    /// Serve on an ephemeral port and exercise `/health` and `/info` over HTTP.
    #[tokio::test]
    async fn full_server_starts_and_stops() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let app = worker_router(make_worker());

        let server = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        // Give the spawned server task a moment to start accepting.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let client = reqwest::Client::new();

        let health_url = format!("http://{addr}/health");
        let health_body = client
            .get(health_url)
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap();
        assert_eq!(health_body, "ok");

        let info_url = format!("http://{addr}/info");
        let info_body: serde_json::Value = client
            .get(info_url)
            .send()
            .await
            .unwrap()
            .json()
            .await
            .unwrap();
        let has_type = info_body.get("type").is_some();
        let has_worker_id = info_body.get("worker_id").is_some();
        assert!(has_type || has_worker_id);

        server.abort();
    }
}