Skip to main content

somatize_worker/
server.rs

//! Axum HTTP/WebSocket server for the worker process.
//!
//! Supports optional token authentication on the `/ws` and `/upload` endpoints;
//! clients pass the token in the `?token=` query parameter.
//! Set a token via [`worker_router_authenticated`] or the `--token` CLI flag.
5
6use crate::env_manager::{EnvManager, EnvType};
7use crate::protocol::*;
8use crate::worker::Worker;
9use axum::Router;
10use axum::extract::ws::{Message, WebSocket};
11use axum::extract::{Query, State, WebSocketUpgrade};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::routing::{get, post};
15use somatize_core::cache::CacheKey;
16use somatize_core::store::{DataStore, LocalDataStore};
17use somatize_core::value::Value;
18use std::collections::HashMap;
19use std::path::PathBuf;
20use std::sync::{Arc, Mutex};
21use std::time::Instant;
22
23/// Shared state for the worker HTTP/WebSocket server.
24struct ServerState {
25    worker: Mutex<Worker>,
26    env_manager: EnvManager,
27    work_dir: PathBuf,
28    /// Optional bearer token for authentication.
29    token: Option<String>,
30    /// Temporary local store for HTTP bulk uploads.
31    temp_store: Arc<LocalDataStore>,
32    /// Track upload times for automatic cleanup.
33    temp_uploads: Mutex<HashMap<CacheKey, Instant>>,
34    /// Active streaming sessions: stream_id → StreamExecutor.
35    active_streams: Mutex<HashMap<String, somatize_runtime::stream::StreamExecutor>>,
36}
37
38/// Build a worker server router (no authentication).
39pub fn worker_router(worker: Worker) -> Router {
40    worker_router_full(worker, "/tmp/soma-envs", "/tmp/soma-work", None)
41}
42
43/// Build a worker server router with custom directories.
44pub fn worker_router_with_dirs(
45    worker: Worker,
46    env_dir: impl Into<PathBuf>,
47    work_dir: impl Into<PathBuf>,
48) -> Router {
49    worker_router_full(worker, env_dir, work_dir, None)
50}
51
52/// Build a worker server router with authentication.
53pub fn worker_router_authenticated(
54    worker: Worker,
55    env_dir: impl Into<PathBuf>,
56    work_dir: impl Into<PathBuf>,
57    token: impl Into<String>,
58) -> Router {
59    worker_router_full(worker, env_dir, work_dir, Some(token.into()))
60}
61
62fn worker_router_full(
63    worker: Worker,
64    env_dir: impl Into<PathBuf>,
65    work_dir: impl Into<PathBuf>,
66    token: Option<String>,
67) -> Router {
68    let work = work_dir.into();
69    std::fs::create_dir_all(&work).ok();
70    let temp_store = worker.temp_store().clone();
71    let state = Arc::new(ServerState {
72        worker: Mutex::new(worker),
73        env_manager: EnvManager::new(env_dir, EnvType::Venv),
74        work_dir: work,
75        token,
76        temp_store,
77        temp_uploads: Mutex::new(HashMap::new()),
78        active_streams: Mutex::new(HashMap::new()),
79    });
80    // Background cleanup: remove temp uploads older than 1 hour
81    let cleanup_state = state.clone();
82    tokio::spawn(async move {
83        let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
84        loop {
85            interval.tick().await;
86            let cutoff = Instant::now() - std::time::Duration::from_secs(3600);
87            let expired: Vec<CacheKey> = {
88                let uploads = cleanup_state
89                    .temp_uploads
90                    .lock()
91                    .unwrap_or_else(|e| e.into_inner());
92                uploads
93                    .iter()
94                    .filter(|(_, created)| **created < cutoff)
95                    .map(|(k, _)| k.clone())
96                    .collect()
97            };
98            if !expired.is_empty() {
99                let mut uploads = cleanup_state
100                    .temp_uploads
101                    .lock()
102                    .unwrap_or_else(|e| e.into_inner());
103                for key in &expired {
104                    let data_ref = somatize_core::store::DataRef::Cached {
105                        cache_key: key.clone(),
106                    };
107                    let _ = cleanup_state.temp_store.remove(&data_ref);
108                    uploads.remove(key);
109                }
110                tracing::info!("Cleaned up {} expired temp uploads", expired.len());
111            }
112        }
113    });
114
115    Router::new()
116        .route("/health", get(health))
117        .route("/info", get(info))
118        .route("/upload", post(upload_data))
119        .route("/ws", get(ws_handler))
120        .with_state(state)
121}
122
123/// Start a worker server on the given address.
124pub async fn serve_worker(worker: Worker, addr: &str) -> Result<(), Box<dyn std::error::Error>> {
125    let listener = tokio::net::TcpListener::bind(addr).await?;
126    tracing::info!("Worker server listening on {addr}");
127    axum::serve(listener, worker_router(worker)).await?;
128    Ok(())
129}
130
131/// Start a worker server with authentication.
132pub async fn serve_worker_authenticated(
133    worker: Worker,
134    addr: &str,
135    token: &str,
136) -> Result<(), Box<dyn std::error::Error>> {
137    let listener = tokio::net::TcpListener::bind(addr).await?;
138    tracing::info!("Worker server listening on {addr} (authenticated)");
139    let router = worker_router_authenticated(worker, "/tmp/soma-envs", "/tmp/soma-work", token);
140    axum::serve(listener, router).await?;
141    Ok(())
142}
143
/// Liveness probe: always answers with the literal body `ok`.
async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
147
148async fn info(State(state): State<Arc<ServerState>>) -> impl IntoResponse {
149    let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
150    let msg = worker.registration_message();
151    axum::Json(serde_json::to_value(msg).unwrap_or_default())
152}
153
154/// Upload data via HTTP for large payloads that exceed WebSocket limits.
155///
156/// Accepts msgpack or JSON body, stores in temp_store, returns DataRef as JSON.
157/// Token auth via `?token=` query param (same as WebSocket).
158async fn upload_data(
159    Query(params): Query<WsParams>,
160    State(state): State<Arc<ServerState>>,
161    body: axum::body::Bytes,
162) -> Result<impl IntoResponse, StatusCode> {
163    // Validate token
164    if let Some(expected) = &state.token {
165        match &params.token {
166            Some(provided) if provided == expected => {}
167            _ => return Err(StatusCode::UNAUTHORIZED),
168        }
169    }
170
171    // Deserialize: try msgpack first, then JSON
172    let value: Value = rmp_serde::from_slice(&body)
173        .or_else(|_| serde_json::from_slice(&body))
174        .map_err(|_| StatusCode::BAD_REQUEST)?;
175
176    let key = CacheKey::hash_data(&body);
177    let data_ref = state
178        .temp_store
179        .put(&key, &value)
180        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
181
182    // Track for cleanup
183    state
184        .temp_uploads
185        .lock()
186        .unwrap_or_else(|e| e.into_inner())
187        .insert(key, Instant::now());
188
189    tracing::info!("Uploaded {} bytes → {data_ref:?}", body.len());
190
191    Ok(axum::Json(
192        serde_json::to_value(&data_ref).unwrap_or_default(),
193    ))
194}
195
196/// Query params for WebSocket authentication.
197#[derive(serde::Deserialize, Default)]
198struct WsParams {
199    token: Option<String>,
200}
201
202async fn ws_handler(
203    ws: WebSocketUpgrade,
204    Query(params): Query<WsParams>,
205    State(state): State<Arc<ServerState>>,
206) -> Result<impl IntoResponse, StatusCode> {
207    // Validate token if server requires one
208    if let Some(expected) = &state.token {
209        match &params.token {
210            Some(provided) if provided == expected => {}
211            _ => return Err(StatusCode::UNAUTHORIZED),
212        }
213    }
214    Ok(ws.on_upgrade(move |socket| handle_ws(socket, state)))
215}
216
217async fn handle_ws(mut socket: WebSocket, state: Arc<ServerState>) {
218    loop {
219        match socket.recv().await {
220            Some(Ok(Message::Text(text))) => {
221                let response = match serde_json::from_str::<CoordinatorToWorker>(&text) {
222                    Ok(CoordinatorToWorker::AssignPlan { plan }) => {
223                        let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
224                        let plan_id = plan.plan_id.clone();
225                        let worker_id = worker.id.clone();
226                        let result = worker.execute_plan(&plan);
227                        let msg = WorkerToCoordinator::PlanResult {
228                            worker_id,
229                            plan_id,
230                            result,
231                        };
232                        serde_json::to_string(&msg).unwrap_or_default()
233                    }
234                    Ok(CoordinatorToWorker::StatusRequest) => {
235                        let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
236                        serde_json::to_string(&worker.registration_message()).unwrap_or_default()
237                    }
238                    Ok(CoordinatorToWorker::CancelPlan { .. }) => {
239                        r#"{"status": "cancel_not_implemented"}"#.to_string()
240                    }
241                    Ok(CoordinatorToWorker::AssignPythonJob { job }) => {
242                        // Send progress messages during execution
243                        let messages = execute_python_job_with_progress(&state, &job);
244                        // Send all but the last as intermediate messages
245                        for msg in &messages[..messages.len().saturating_sub(1)] {
246                            if socket
247                                .send(Message::Text(msg.clone().into()))
248                                .await
249                                .is_err()
250                            {
251                                break;
252                            }
253                        }
254                        // Return the last message (result) through the normal path
255                        messages.into_iter().last().unwrap_or_default()
256                    }
257                    Ok(CoordinatorToWorker::Ping) => r#"{"type":"Pong"}"#.to_string(),
258                    Ok(CoordinatorToWorker::Registered { .. }) => continue,
259                    Ok(CoordinatorToWorker::Shutdown { reason }) => {
260                        tracing::info!("Shutdown requested: {reason}");
261                        let _ = socket
262                            .send(Message::Text(r#"{"type":"ShutdownAck"}"#.into()))
263                            .await;
264                        std::process::exit(0);
265                    }
266                    Err(e) => {
267                        format!(r#"{{"error": "invalid message: {e}"}}"#)
268                    }
269                };
270
271                if socket.send(Message::Text(response.into())).await.is_err() {
272                    break;
273                }
274            }
275            Some(Ok(Message::Binary(bytes))) => {
276                if let Ok(stream_msg) = rmp_serde::from_slice::<StreamMessage>(&bytes) {
277                    let reply = handle_stream_message(stream_msg, &state);
278                    if let Some(reply_msg) = reply {
279                        let reply_bytes = rmp_serde::to_vec(&reply_msg).unwrap_or_default();
280                        if socket
281                            .send(Message::Binary(reply_bytes.into()))
282                            .await
283                            .is_err()
284                        {
285                            break;
286                        }
287                    }
288                }
289            }
290            Some(Ok(Message::Close(_))) | None => break,
291            _ => {}
292        }
293    }
294}
295
296/// Handle a streaming protocol message. Returns an optional reply.
297fn handle_stream_message(msg: StreamMessage, state: &Arc<ServerState>) -> Option<StreamMessage> {
298    use somatize_runtime::stream::{FittedFilter, StreamExecutor};
299
300    match msg {
301        StreamMessage::StreamBegin {
302            stream_id, plan, ..
303        } => {
304            // Build StreamExecutor from the plan's filters
305            let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
306
307            // Register pickled filters (streaming uses system python — venv managed by execute_plan)
308            for sf in &plan.filters {
309                let filter = Box::new(crate::worker::PickledFilterRunner {
310                    pickled_bytes: sf.pickled_filter.clone(),
311                    node_id: sf.node_id.clone(),
312                    python_path: "python3".to_string(),
313                    requirements: sf.requirements.clone(),
314                });
315                worker.register_filter(&sf.node_id, filter);
316                if let Some(s) = &sf.state {
317                    worker.set_filter_state(&sf.node_id, s.clone());
318                }
319            }
320
321            // Build FittedFilter list from plan node order
322            let node_ids = plan.plan.node_ids();
323            let fitted: Vec<FittedFilter> = node_ids
324                .iter()
325                .filter_map(|id| {
326                    let filter = worker.get_filter(id)?;
327                    let filter_state = worker.get_filter_state(id);
328                    Some(FittedFilter {
329                        name: id.to_string(),
330                        filter,
331                        state: filter_state,
332                    })
333                })
334                .collect();
335
336            let executor = StreamExecutor::new(fitted);
337            state
338                .active_streams
339                .lock()
340                .unwrap_or_else(|e| e.into_inner())
341                .insert(stream_id, executor);
342
343            None // No reply for StreamBegin
344        }
345        StreamMessage::ChunkData {
346            stream_id,
347            chunk_index,
348            value,
349        } => {
350            let mut streams = state
351                .active_streams
352                .lock()
353                .unwrap_or_else(|e| e.into_inner());
354            if let Some(executor) = streams.get_mut(&stream_id) {
355                match executor.process_chunk(value) {
356                    Ok(Some(result)) => Some(StreamMessage::ChunkResult {
357                        stream_id,
358                        chunk_index,
359                        value: result,
360                    }),
361                    Ok(None) => None, // Barrier mode — no result yet
362                    Err(e) => Some(StreamMessage::StreamComplete {
363                        stream_id,
364                        result: PlanResult::Failed {
365                            error: e.to_string(),
366                            duration_ms: 0,
367                        },
368                    }),
369                }
370            } else {
371                Some(StreamMessage::StreamComplete {
372                    stream_id,
373                    result: PlanResult::Failed {
374                        error: "unknown stream_id".to_string(),
375                        duration_ms: 0,
376                    },
377                })
378            }
379        }
380        StreamMessage::StreamEnd { stream_id } => {
381            let mut streams = state
382                .active_streams
383                .lock()
384                .unwrap_or_else(|e| e.into_inner());
385            if let Some(mut executor) = streams.remove(&stream_id) {
386                // Flush barrier filters
387                let output = executor
388                    .flush()
389                    .unwrap_or(None)
390                    .unwrap_or(somatize_core::value::Value::Empty);
391                Some(StreamMessage::StreamComplete {
392                    stream_id,
393                    result: PlanResult::Success {
394                        output,
395                        duration_ms: 0,
396                        states: std::collections::HashMap::new(),
397                    },
398                })
399            } else {
400                None
401            }
402        }
403        _ => None,
404    }
405}
406
407/// Execute a Python pipeline job with progress reporting.
408fn execute_python_job_with_progress(state: &ServerState, job: &PythonPipelineJob) -> Vec<String> {
409    let start = Instant::now();
410    let mut messages = Vec::new();
411    let worker_id = {
412        let w = state.worker.lock().unwrap_or_else(|e| e.into_inner());
413        w.id.clone()
414    };
415
416    let progress = |wid: &str, jid: &str, phase: &str, step: u32, total: u32| -> String {
417        serde_json::to_string(&WorkerToCoordinator::JobProgress {
418            worker_id: wid.into(),
419            job_id: jid.into(),
420            phase: phase.into(),
421            step,
422            total,
423            metrics: serde_json::json!({}),
424        })
425        .unwrap_or_default()
426    };
427
428    // Phase 1/4: Environment setup
429    messages.push(progress(&worker_id, &job.job_id, "environment", 1, 4));
430
431    let python = match state
432        .env_manager
433        .ensure_env(&job.pipeline_id, &job.requirements)
434    {
435        Ok(p) => p,
436        Err(e) => {
437            tracing::error!("Failed to create env for pipeline {}: {e}", job.pipeline_id);
438            let msg = WorkerToCoordinator::JobResult {
439                worker_id,
440                job_id: job.job_id.clone(),
441                success: false,
442                metrics: serde_json::json!({}),
443                output: format!("Environment setup failed: {e}"),
444                duration_ms: start.elapsed().as_millis() as u64,
445            };
446            messages.push(serde_json::to_string(&msg).unwrap_or_default());
447            return messages;
448        }
449    };
450
451    // Phase 2/4: Write files
452    messages.push(progress(&worker_id, &job.job_id, "write_files", 2, 4));
453
454    let job_dir = state.work_dir.join(format!("job-{}", job.job_id));
455    if let Err(e) = std::fs::create_dir_all(&job_dir) {
456        let msg = WorkerToCoordinator::JobResult {
457            worker_id,
458            job_id: job.job_id.clone(),
459            success: false,
460            metrics: serde_json::json!({}),
461            output: format!("Failed to create work dir: {e}"),
462            duration_ms: start.elapsed().as_millis() as u64,
463        };
464        messages.push(serde_json::to_string(&msg).unwrap_or_default());
465        return messages;
466    }
467
468    for file in &job.files {
469        let file_path = job_dir.join(&file.path);
470        if let Some(parent) = file_path.parent() {
471            std::fs::create_dir_all(parent).ok();
472        }
473        if let Err(e) = std::fs::write(&file_path, &file.content) {
474            tracing::error!("Failed to write {}: {e}", file.path);
475        }
476    }
477
478    // Phase 3/4: Execute
479    messages.push(progress(&worker_id, &job.job_id, "execute", 3, 4));
480
481    tracing::info!(
482        "Executing job {} with python: {}",
483        job.job_id,
484        python.display()
485    );
486
487    let output = std::process::Command::new(&python)
488        .arg(&job.entry_point)
489        .current_dir(&job_dir)
490        .env("PYTHONPATH", &job_dir)
491        .output();
492
493    let duration_ms = start.elapsed().as_millis() as u64;
494
495    // Phase 4/4: Collect results
496    let _ = std::fs::remove_dir_all(&job_dir);
497    messages.push(progress(&worker_id, &job.job_id, "collect_results", 4, 4));
498
499    let result_msg = match output {
500        Ok(out) => {
501            let stdout = String::from_utf8_lossy(&out.stdout).to_string();
502            let stderr = String::from_utf8_lossy(&out.stderr).to_string();
503            let success = out.status.success();
504
505            let metrics = stdout
506                .lines()
507                .rev()
508                .find_map(|line| serde_json::from_str::<serde_json::Value>(line).ok())
509                .unwrap_or(serde_json::json!({}));
510
511            if !success {
512                tracing::warn!(
513                    "Job {} failed: {}",
514                    job.job_id,
515                    stderr.chars().take(200).collect::<String>()
516                );
517            }
518
519            WorkerToCoordinator::JobResult {
520                worker_id,
521                job_id: job.job_id.clone(),
522                success,
523                metrics,
524                output: if success {
525                    stdout
526                } else {
527                    format!("STDERR:\n{stderr}\nSTDOUT:\n{stdout}")
528                },
529                duration_ms,
530            }
531        }
532        Err(e) => WorkerToCoordinator::JobResult {
533            worker_id,
534            job_id: job.job_id.clone(),
535            success: false,
536            metrics: serde_json::json!({}),
537            output: format!("Failed to execute: {e}"),
538            duration_ms,
539        },
540    };
541    messages.push(serde_json::to_string(&result_msg).unwrap_or_default());
542    messages
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::Capabilities;

    /// Build a small CPU-only worker for the tests in this module.
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// Building the router must not panic inside a Tokio runtime.
    #[tokio::test]
    async fn router_builds() {
        let _router = worker_router(make_worker());
    }

    /// The health handler answers with the literal `ok`.
    #[tokio::test]
    async fn health_returns_ok() {
        assert_eq!(health().await, "ok");
    }

    /// End-to-end check: serve on an ephemeral port and query `/health` and `/info`.
    #[tokio::test]
    async fn full_server_starts_and_stops() {
        let worker = make_worker();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            axum::serve(listener, worker_router(worker)).await.unwrap();
        });

        // Give the spawned server a moment to start accepting connections.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let client = reqwest::Client::new();
        let base = format!("http://{addr}");

        let health_body = client
            .get(format!("{base}/health"))
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap();
        assert_eq!(health_body, "ok");

        let json: serde_json::Value = client
            .get(format!("{base}/info"))
            .send()
            .await
            .unwrap()
            .json()
            .await
            .unwrap();
        let has_identity = ["type", "worker_id"]
            .iter()
            .any(|key| json.get(*key).is_some());
        assert!(has_identity);

        server.abort();
    }
}