Skip to main content

somatize_worker/
server.rs

1//! Axum HTTP/WebSocket server for the worker process.
2//!
3//! Supports optional bearer token authentication on WebSocket connections.
4//! Set a token via [`worker_router_authenticated`] or the `--token` CLI flag.
5
6use crate::env_manager::{EnvManager, EnvType};
7use crate::protocol::*;
8use crate::worker::Worker;
9use axum::Router;
10use axum::extract::DefaultBodyLimit;
11use axum::extract::ws::{Message, WebSocket};
12use axum::extract::{Query, State, WebSocketUpgrade};
13use axum::http::StatusCode;
14use axum::response::IntoResponse;
15use axum::routing::{get, post};
16use somatize_core::cache::CacheKey;
17use somatize_core::store::{DataStore, LocalDataStore};
18use somatize_core::value::Value;
19use std::collections::HashMap;
20use std::path::PathBuf;
21use std::sync::{Arc, Mutex};
22use std::time::Instant;
23
24/// Shared state for the worker HTTP/WebSocket server.
25struct ServerState {
26    worker: Mutex<Worker>,
27    env_manager: EnvManager,
28    work_dir: PathBuf,
29    /// Optional bearer token for authentication.
30    token: Option<String>,
31    /// Temporary local store for HTTP bulk uploads.
32    temp_store: Arc<LocalDataStore>,
33    /// Track upload times for automatic cleanup.
34    temp_uploads: Mutex<HashMap<CacheKey, Instant>>,
35    /// Active streaming sessions: stream_id → (executor, start_time).
36    active_streams:
37        Mutex<HashMap<String, (somatize_runtime::executors::stream::StreamExecutor, Instant)>>,
38}
39
40/// Build a worker server router (no authentication).
41pub fn worker_router(worker: Worker) -> Router {
42    worker_router_full(worker, "/tmp/soma-envs", "/tmp/soma-work", None)
43}
44
45/// Build a worker server router with custom directories.
46pub fn worker_router_with_dirs(
47    worker: Worker,
48    env_dir: impl Into<PathBuf>,
49    work_dir: impl Into<PathBuf>,
50) -> Router {
51    worker_router_full(worker, env_dir, work_dir, None)
52}
53
54/// Build a worker server router with authentication.
55pub fn worker_router_authenticated(
56    worker: Worker,
57    env_dir: impl Into<PathBuf>,
58    work_dir: impl Into<PathBuf>,
59    token: impl Into<String>,
60) -> Router {
61    worker_router_full(worker, env_dir, work_dir, Some(token.into()))
62}
63
64fn worker_router_full(
65    worker: Worker,
66    env_dir: impl Into<PathBuf>,
67    work_dir: impl Into<PathBuf>,
68    token: Option<String>,
69) -> Router {
70    let work = work_dir.into();
71    std::fs::create_dir_all(&work).ok();
72    let temp_store = worker.temp_store().clone();
73    let state = Arc::new(ServerState {
74        worker: Mutex::new(worker),
75        env_manager: EnvManager::new(env_dir, EnvType::Venv),
76        work_dir: work,
77        token,
78        temp_store,
79        temp_uploads: Mutex::new(HashMap::new()),
80        active_streams: Mutex::new(HashMap::new()),
81    });
82    // Background cleanup: remove temp uploads older than 1 hour
83    let cleanup_state = state.clone();
84    tokio::spawn(async move {
85        let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
86        loop {
87            interval.tick().await;
88            let cutoff = Instant::now() - std::time::Duration::from_secs(3600);
89            let expired: Vec<CacheKey> = {
90                let uploads = cleanup_state
91                    .temp_uploads
92                    .lock()
93                    .unwrap_or_else(|e| e.into_inner());
94                uploads
95                    .iter()
96                    .filter(|(_, created)| **created < cutoff)
97                    .map(|(k, _)| k.clone())
98                    .collect()
99            };
100            if !expired.is_empty() {
101                let mut uploads = cleanup_state
102                    .temp_uploads
103                    .lock()
104                    .unwrap_or_else(|e| e.into_inner());
105                for key in &expired {
106                    let data_ref = somatize_core::store::DataRef::Cached {
107                        cache_key: key.clone(),
108                    };
109                    let _ = cleanup_state.temp_store.remove(&data_ref);
110                    uploads.remove(key);
111                }
112                tracing::info!("Cleaned up {} expired temp uploads", expired.len());
113            }
114        }
115    });
116
117    Router::new()
118        .route("/health", get(health))
119        .route("/info", get(info))
120        .route("/upload", post(upload_data))
121        .route("/download", get(download_data))
122        .route("/ws", get(ws_handler))
123        .layer(DefaultBodyLimit::disable()) // No limit — workers handle arbitrary data sizes
124        .with_state(state)
125}
126
127/// Start a worker server on the given address.
128pub async fn serve_worker(worker: Worker, addr: &str) -> Result<(), Box<dyn std::error::Error>> {
129    let listener = tokio::net::TcpListener::bind(addr).await?;
130    tracing::info!("Worker server listening on {addr}");
131    axum::serve(listener, worker_router(worker)).await?;
132    Ok(())
133}
134
135/// Start a worker server with authentication.
136pub async fn serve_worker_authenticated(
137    worker: Worker,
138    addr: &str,
139    token: &str,
140) -> Result<(), Box<dyn std::error::Error>> {
141    let listener = tokio::net::TcpListener::bind(addr).await?;
142    tracing::info!("Worker server listening on {addr} (authenticated)");
143    let router = worker_router_authenticated(worker, "/tmp/soma-envs", "/tmp/soma-work", token);
144    axum::serve(listener, router).await?;
145    Ok(())
146}
147
/// Liveness probe: always answers with the plain-text body `ok`.
async fn health() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
151
152async fn info(State(state): State<Arc<ServerState>>) -> impl IntoResponse {
153    let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
154    let msg = worker.registration_message();
155    axum::Json(serde_json::to_value(msg).unwrap_or_default())
156}
157
158/// Upload data via HTTP for large payloads that exceed WebSocket limits.
159///
160/// Accepts msgpack or JSON body, stores in temp_store, returns DataRef as JSON.
161/// Token auth via `?token=` query param (same as WebSocket).
162async fn upload_data(
163    Query(params): Query<WsParams>,
164    State(state): State<Arc<ServerState>>,
165    body: axum::body::Bytes,
166) -> Result<impl IntoResponse, StatusCode> {
167    // Validate token
168    if let Some(expected) = &state.token {
169        match &params.token {
170            Some(provided) if provided == expected => {}
171            _ => return Err(StatusCode::UNAUTHORIZED),
172        }
173    }
174
175    // Deserialize: try msgpack first, then JSON
176    let value: Value = rmp_serde::from_slice(&body)
177        .or_else(|_| serde_json::from_slice(&body))
178        .map_err(|_| StatusCode::BAD_REQUEST)?;
179
180    let key = CacheKey::hash_data(&body);
181    let data_ref = state
182        .temp_store
183        .put(&key, &value)
184        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
185
186    // Track for cleanup
187    state
188        .temp_uploads
189        .lock()
190        .unwrap_or_else(|e| e.into_inner())
191        .insert(key, Instant::now());
192
193    tracing::info!("Uploaded {} bytes → {data_ref:?}", body.len());
194
195    Ok(axum::Json(
196        serde_json::to_value(&data_ref).unwrap_or_default(),
197    ))
198}
199
200/// Query params for data download.
201#[derive(serde::Deserialize)]
202struct DownloadParams {
203    /// JSON-serialized DataRef (same format returned by /upload).
204    #[serde(rename = "ref")]
205    data_ref: String,
206    token: Option<String>,
207}
208
209/// Download data from the worker's temp store by DataRef.
210///
211/// Returns msgpack-encoded Value. Used by clients to resolve
212/// `OutputDelivery::Reference` results after plan execution.
213async fn download_data(
214    Query(params): Query<DownloadParams>,
215    State(state): State<Arc<ServerState>>,
216) -> Result<impl IntoResponse, StatusCode> {
217    // Validate token
218    if let Some(expected) = &state.token {
219        match &params.token {
220            Some(provided) if provided == expected => {}
221            _ => return Err(StatusCode::UNAUTHORIZED),
222        }
223    }
224
225    let data_ref: somatize_core::store::DataRef =
226        serde_json::from_str(&params.data_ref).map_err(|_| StatusCode::BAD_REQUEST)?;
227
228    let value = state
229        .temp_store
230        .get(&data_ref)
231        .map_err(|_| StatusCode::NOT_FOUND)?;
232
233    let bytes = serde_json::to_vec(&value).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
234
235    Ok((
236        [(axum::http::header::CONTENT_TYPE, "application/json")],
237        bytes,
238    ))
239}
240
241/// Query params for WebSocket authentication.
242#[derive(serde::Deserialize, Default)]
243struct WsParams {
244    token: Option<String>,
245}
246
247async fn ws_handler(
248    ws: WebSocketUpgrade,
249    Query(params): Query<WsParams>,
250    State(state): State<Arc<ServerState>>,
251) -> Result<impl IntoResponse, StatusCode> {
252    // Validate token if server requires one
253    if let Some(expected) = &state.token {
254        match &params.token {
255            Some(provided) if provided == expected => {}
256            _ => return Err(StatusCode::UNAUTHORIZED),
257        }
258    }
259    Ok(ws
260        .max_message_size(usize::MAX) // No limit on incoming WS messages
261        .max_frame_size(usize::MAX)
262        .on_upgrade(move |socket| handle_ws(socket, state)))
263}
264
265async fn handle_ws(mut socket: WebSocket, state: Arc<ServerState>) {
266    loop {
267        match socket.recv().await {
268            Some(Ok(Message::Text(text))) => {
269                let response = match serde_json::from_str::<CoordinatorToWorker>(&text) {
270                    Ok(CoordinatorToWorker::AssignPlan { plan }) => {
271                        let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
272                        let plan_id = plan.plan_id.clone();
273                        let worker_id = worker.id.clone();
274                        let result = worker.execute_plan(&plan);
275                        let msg = WorkerToCoordinator::PlanResult {
276                            worker_id,
277                            plan_id,
278                            result,
279                        };
280                        serde_json::to_string(&msg).unwrap_or_default()
281                    }
282                    Ok(CoordinatorToWorker::StatusRequest) => {
283                        let worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
284                        serde_json::to_string(&worker.registration_message()).unwrap_or_default()
285                    }
286                    Ok(CoordinatorToWorker::CancelPlan { .. }) => {
287                        r#"{"status": "cancel_not_implemented"}"#.to_string()
288                    }
289                    Ok(CoordinatorToWorker::AssignPythonJob { job }) => {
290                        // Send progress messages during execution
291                        let messages = execute_python_job_with_progress(&state, &job);
292                        // Send all but the last as intermediate messages
293                        for msg in &messages[..messages.len().saturating_sub(1)] {
294                            if socket
295                                .send(Message::Text(msg.clone().into()))
296                                .await
297                                .is_err()
298                            {
299                                break;
300                            }
301                        }
302                        // Return the last message (result) through the normal path
303                        messages.into_iter().last().unwrap_or_default()
304                    }
305                    Ok(CoordinatorToWorker::Ping) => r#"{"type":"Pong"}"#.to_string(),
306                    Ok(CoordinatorToWorker::Registered { .. }) => continue,
307                    Ok(CoordinatorToWorker::Shutdown { reason }) => {
308                        tracing::info!("Shutdown requested: {reason}");
309                        let _ = socket
310                            .send(Message::Text(r#"{"type":"ShutdownAck"}"#.into()))
311                            .await;
312                        std::process::exit(0);
313                    }
314                    Ok(
315                        CoordinatorToWorker::GetState { .. }
316                        | CoordinatorToWorker::SetState { .. }
317                        | CoordinatorToWorker::GetGradients { .. }
318                        | CoordinatorToWorker::ApplyGradients { .. },
319                    ) => {
320                        r#"{"error":"get_state/set_state/get_gradients/apply_gradients not yet implemented for SubprocessFilter"}"#.to_string()
321                    }
322                    Err(e) => {
323                        format!(r#"{{"error": "invalid message: {e}"}}"#)
324                    }
325                };
326
327                if socket.send(Message::Text(response.into())).await.is_err() {
328                    break;
329                }
330            }
331            Some(Ok(Message::Binary(bytes))) => {
332                if let Ok(stream_msg) = rmp_serde::from_slice::<StreamMessage>(&bytes) {
333                    let reply = handle_stream_message(stream_msg, &state);
334                    if let Some(reply_msg) = reply {
335                        let reply_bytes = rmp_serde::to_vec(&reply_msg).unwrap_or_default();
336                        if socket
337                            .send(Message::Binary(reply_bytes.into()))
338                            .await
339                            .is_err()
340                        {
341                            break;
342                        }
343                    }
344                }
345            }
346            Some(Ok(Message::Close(_))) | None => break,
347            _ => {}
348        }
349    }
350}
351
352/// Handle a streaming protocol message. Returns an optional reply.
353fn handle_stream_message(msg: StreamMessage, state: &Arc<ServerState>) -> Option<StreamMessage> {
354    use somatize_runtime::executors::stream::{FittedFilter, StreamExecutor};
355
356    match msg {
357        StreamMessage::StreamBegin {
358            stream_id, plan, ..
359        } => {
360            // Build StreamExecutor from the plan's filters
361            let mut worker = state.worker.lock().unwrap_or_else(|e| e.into_inner());
362
363            // Register filters via SubprocessFilter backed by a shared PythonProcess
364            let filter_specs: Vec<(String, Vec<u8>, bool)> = plan
365                .filters
366                .iter()
367                .map(|sf| (sf.node_id.clone(), sf.pickled_filter.clone(), sf.trainable))
368                .collect();
369
370            if !filter_specs.is_empty() {
371                let process = Arc::new(std::sync::Mutex::new(
372                    crate::python_process::PythonProcess::spawn("python3", &filter_specs)
373                        .expect("PythonProcess spawn failed"),
374                ));
375
376                for sf in &plan.filters {
377                    let filter: Box<dyn somatize_core::filter::Filter> =
378                        Box::new(crate::python_process::SubprocessFilter::new(
379                            process.clone(),
380                            sf.node_id.clone(),
381                            sf.trainable,
382                        ));
383                    worker.register_filter(&sf.node_id, filter);
384                    if let Some(s) = &sf.state {
385                        worker.set_filter_state(&sf.node_id, s.clone());
386                    }
387                }
388            }
389
390            // Build FittedFilter list from plan node order
391            let node_ids = plan.plan.node_ids();
392            let fitted: Vec<FittedFilter> = node_ids
393                .iter()
394                .filter_map(|id| {
395                    let filter = worker.get_filter(id)?;
396                    let filter_state = worker.get_filter_state(id);
397                    Some(FittedFilter {
398                        name: id.to_string(),
399                        filter,
400                        state: filter_state,
401                    })
402                })
403                .collect();
404
405            let executor = StreamExecutor::new(fitted);
406            state
407                .active_streams
408                .lock()
409                .unwrap_or_else(|e| e.into_inner())
410                .insert(stream_id, (executor, Instant::now()));
411
412            None // No reply for StreamBegin
413        }
414        StreamMessage::ChunkData {
415            stream_id,
416            chunk_index,
417            value,
418        } => {
419            let mut streams = state
420                .active_streams
421                .lock()
422                .unwrap_or_else(|e| e.into_inner());
423            if let Some((executor, _)) = streams.get_mut(&stream_id) {
424                match executor.process_chunk(value) {
425                    Ok(Some(result)) => Some(StreamMessage::ChunkResult {
426                        stream_id,
427                        chunk_index,
428                        value: result,
429                    }),
430                    Ok(None) => None, // Barrier mode — no result yet
431                    Err(e) => Some(StreamMessage::StreamComplete {
432                        stream_id,
433                        result: PlanResult::Failed {
434                            error: e.to_string(),
435                            duration_ms: 0,
436                        },
437                    }),
438                }
439            } else {
440                Some(StreamMessage::StreamComplete {
441                    stream_id,
442                    result: PlanResult::Failed {
443                        error: "unknown stream_id".to_string(),
444                        duration_ms: 0,
445                    },
446                })
447            }
448        }
449        StreamMessage::StreamEnd { stream_id } => {
450            let mut streams = state
451                .active_streams
452                .lock()
453                .unwrap_or_else(|e| e.into_inner());
454            if let Some((mut executor, start)) = streams.remove(&stream_id) {
455                let duration_ms = start.elapsed().as_millis() as u64;
456                // Flush barrier filters — only Barrier mode accumulates data.
457                let output = executor
458                    .flush()
459                    .unwrap_or(None)
460                    .unwrap_or(somatize_core::value::Value::Empty);
461                Some(StreamMessage::StreamComplete {
462                    stream_id,
463                    result: PlanResult::Success {
464                        output: OutputDelivery::Inline { value: output },
465                        duration_ms,
466                        states: std::collections::HashMap::new(),
467                    },
468                })
469            } else {
470                None
471            }
472        }
473        _ => None,
474    }
475}
476
477/// Execute a Python pipeline job with progress reporting.
478fn execute_python_job_with_progress(state: &ServerState, job: &PythonPipelineJob) -> Vec<String> {
479    let start = Instant::now();
480    let mut messages = Vec::new();
481    let worker_id = {
482        let w = state.worker.lock().unwrap_or_else(|e| e.into_inner());
483        w.id.clone()
484    };
485
486    let progress = |wid: &str, jid: &str, phase: &str, step: u32, total: u32| -> String {
487        serde_json::to_string(&WorkerToCoordinator::JobProgress {
488            worker_id: wid.into(),
489            job_id: jid.into(),
490            phase: phase.into(),
491            step,
492            total,
493            metrics: serde_json::json!({}),
494        })
495        .unwrap_or_default()
496    };
497
498    // Phase 1/4: Environment setup
499    messages.push(progress(&worker_id, &job.job_id, "environment", 1, 4));
500
501    let python = match state
502        .env_manager
503        .ensure_env(&job.pipeline_id, &job.requirements)
504    {
505        Ok(p) => p,
506        Err(e) => {
507            tracing::error!("Failed to create env for pipeline {}: {e}", job.pipeline_id);
508            let msg = WorkerToCoordinator::JobResult {
509                worker_id,
510                job_id: job.job_id.clone(),
511                success: false,
512                metrics: serde_json::json!({}),
513                output: format!("Environment setup failed: {e}"),
514                duration_ms: start.elapsed().as_millis() as u64,
515            };
516            messages.push(serde_json::to_string(&msg).unwrap_or_default());
517            return messages;
518        }
519    };
520
521    // Phase 2/4: Write files
522    messages.push(progress(&worker_id, &job.job_id, "write_files", 2, 4));
523
524    let job_dir = state.work_dir.join(format!("job-{}", job.job_id));
525    if let Err(e) = std::fs::create_dir_all(&job_dir) {
526        let msg = WorkerToCoordinator::JobResult {
527            worker_id,
528            job_id: job.job_id.clone(),
529            success: false,
530            metrics: serde_json::json!({}),
531            output: format!("Failed to create work dir: {e}"),
532            duration_ms: start.elapsed().as_millis() as u64,
533        };
534        messages.push(serde_json::to_string(&msg).unwrap_or_default());
535        return messages;
536    }
537
538    for file in &job.files {
539        let file_path = job_dir.join(&file.path);
540        if let Some(parent) = file_path.parent() {
541            std::fs::create_dir_all(parent).ok();
542        }
543        if let Err(e) = std::fs::write(&file_path, &file.content) {
544            tracing::error!("Failed to write {}: {e}", file.path);
545        }
546    }
547
548    // Phase 3/4: Execute
549    messages.push(progress(&worker_id, &job.job_id, "execute", 3, 4));
550
551    tracing::info!(
552        "Executing job {} with python: {}",
553        job.job_id,
554        python.display()
555    );
556
557    let output = std::process::Command::new(&python)
558        .arg(&job.entry_point)
559        .current_dir(&job_dir)
560        .env("PYTHONPATH", &job_dir)
561        .output();
562
563    let duration_ms = start.elapsed().as_millis() as u64;
564
565    // Phase 4/4: Collect results
566    let _ = std::fs::remove_dir_all(&job_dir);
567    messages.push(progress(&worker_id, &job.job_id, "collect_results", 4, 4));
568
569    let result_msg = match output {
570        Ok(out) => {
571            let stdout = String::from_utf8_lossy(&out.stdout).to_string();
572            let stderr = String::from_utf8_lossy(&out.stderr).to_string();
573            let success = out.status.success();
574
575            let metrics = stdout
576                .lines()
577                .rev()
578                .find_map(|line| serde_json::from_str::<serde_json::Value>(line).ok())
579                .unwrap_or(serde_json::json!({}));
580
581            if !success {
582                tracing::warn!(
583                    "Job {} failed: {}",
584                    job.job_id,
585                    stderr.chars().take(200).collect::<String>()
586                );
587            }
588
589            WorkerToCoordinator::JobResult {
590                worker_id,
591                job_id: job.job_id.clone(),
592                success,
593                metrics,
594                output: if success {
595                    stdout
596                } else {
597                    format!("STDERR:\n{stderr}\nSTDOUT:\n{stdout}")
598                },
599                duration_ms,
600            }
601        }
602        Err(e) => WorkerToCoordinator::JobResult {
603            worker_id,
604            job_id: job.job_id.clone(),
605            success: false,
606            metrics: serde_json::json!({}),
607            output: format!("Failed to execute: {e}"),
608            duration_ms,
609        },
610    };
611    messages.push(serde_json::to_string(&result_msg).unwrap_or_default());
612    messages
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::Capabilities;

    /// A worker with fixed, fake capabilities.
    fn make_worker() -> Worker {
        Worker::new(
            "test_worker",
            Capabilities {
                cpu_cores: 4,
                ram_bytes: 8_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["test".into()],
            },
        )
    }

    /// Serve `router` on an ephemeral localhost port.
    ///
    /// The listener is bound before this returns, so requests sent before the
    /// server task starts polling just wait in the accept backlog; no sleep is
    /// needed.
    async fn spawn_server(router: Router) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            axum::serve(listener, router).await.unwrap();
        });
        (addr, server)
    }

    #[tokio::test]
    async fn router_builds() {
        let _router = worker_router(make_worker());
    }

    #[tokio::test]
    async fn health_returns_ok() {
        let resp = health().await;
        assert_eq!(resp, "ok");
    }

    #[tokio::test]
    async fn full_server_starts_and_stops() {
        let (addr, server) = spawn_server(worker_router(make_worker())).await;

        let client = reqwest::Client::new();
        let resp = client
            .get(format!("http://{addr}/health"))
            .send()
            .await
            .unwrap();
        assert_eq!(resp.text().await.unwrap(), "ok");

        let resp = client
            .get(format!("http://{addr}/info"))
            .send()
            .await
            .unwrap();
        let json: serde_json::Value = resp.json().await.unwrap();
        assert!(json.get("type").is_some() || json.get("worker_id").is_some());

        server.abort();
    }

    #[tokio::test]
    async fn upload_rejects_missing_or_wrong_token() {
        let base = std::env::temp_dir().join(format!("soma-auth-test-{}", std::process::id()));
        let router = worker_router_authenticated(
            make_worker(),
            base.join("envs"),
            base.join("work"),
            "secret",
        );
        let (addr, server) = spawn_server(router).await;

        let client = reqwest::Client::new();
        let missing = client
            .post(format!("http://{addr}/upload"))
            .body("{}")
            .send()
            .await
            .unwrap();
        assert_eq!(missing.status().as_u16(), 401);

        let wrong = client
            .post(format!("http://{addr}/upload?token=wrong"))
            .body("{}")
            .send()
            .await
            .unwrap();
        assert_eq!(wrong.status().as_u16(), 401);

        server.abort();
        let _ = std::fs::remove_dir_all(&base);
    }
}