Skip to main content

somatize_worker/
protocol.rs

1//! Wire protocol for coordinator ↔ worker communication.
2//!
3//! Defines message types for plan assignment, results, heartbeats,
4//! Python job management, and worker capabilities.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::{DataRef, DataStore};
11use somatize_core::value::Value;
12
/// Unique worker identifier.
///
/// A plain `String` alias: it documents intent at use sites but gives no
/// type-level distinction from [`PlanId`] or any other string.
pub type WorkerId = String;

/// Unique plan execution identifier.
///
/// A plain `String` alias: it documents intent at use sites but gives no
/// type-level distinction from [`WorkerId`] or any other string.
pub type PlanId = String;
18
/// Hardware and software capabilities of a worker.
///
/// Carried in [`WorkerToCoordinator::Register`] when a worker announces itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
    /// Number of CPU cores.
    pub cpu_cores: usize,
    /// Total RAM in bytes.
    pub ram_bytes: u64,
    /// GPU information (see [`GpuInfo`]); may be empty.
    pub gpus: Vec<GpuInfo>,
    /// Available Python environments.
    ///
    /// NOTE(review): presumably environment names (the tests use "py310" and
    /// "py311") — the expected format is not defined in this file.
    pub python_envs: Vec<String>,
    /// User-defined tags for routing (e.g. "gpu", "training", "inference").
    pub tags: Vec<String>,
}
33
/// GPU hardware info.
///
/// Listed in [`Capabilities::gpus`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuInfo {
    /// Device name (the tests use "A100").
    pub name: String,
    /// GPU memory in bytes.
    ///
    /// NOTE(review): presumably total device memory rather than free memory —
    /// confirm against the code that fills this in.
    pub memory_bytes: u64,
}
40
/// Current load metrics reported by a worker.
///
/// Carried in [`WorkerToCoordinator::Heartbeat`].
///
/// NOTE(review): the scale of the `*_usage` fields is not defined in this
/// file. The tests use fractional values such as 0.45, which suggests a
/// 0.0–1.0 range rather than percent — confirm with the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadMetrics {
    /// CPU usage (scale: see the note on the struct).
    pub cpu_usage: f32,
    /// Memory usage (scale: see the note on the struct).
    pub memory_usage: f32,
    /// GPU usage values.
    ///
    /// NOTE(review): presumably one entry per GPU, in the same order as
    /// `Capabilities::gpus` — confirm.
    pub gpu_usage: Vec<f32>,
    /// Number of active plans (presumably plans currently executing — confirm).
    pub active_plans: usize,
    /// Queue depth (presumably work waiting to start — confirm).
    pub queue_depth: usize,
    /// UTC timestamp attached to this report (presumably the sampling time).
    pub timestamp: DateTime<Utc>,
}
51
/// How input data is provided to a worker.
///
/// Serialized as an internally tagged enum: the variant name is written to a
/// `source` field next to the variant's own fields. Marked `#[non_exhaustive]`,
/// so code in other crates must include a wildcard arm when matching.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "source")]
#[non_exhaustive]
pub enum InputSource {
    /// Data embedded directly in the message (small payloads).
    Inline { value: Value },
    /// Data referenced in a remote store (large payloads).
    ///
    /// The reference is turned into a `Value` by `InputSource::resolve`.
    Reference { data_ref: DataRef },
}
62
63impl InputSource {
64    /// Resolve the input to a concrete Value.
65    /// Tries persistent DataStore first, then temp store for HTTP uploads.
66    pub fn resolve(
67        &self,
68        data_store: Option<&dyn somatize_core::store::DataStore>,
69        temp_store: &somatize_core::store::LocalDataStore,
70    ) -> Value {
71        match self {
72            InputSource::Inline { value } => value.clone(),
73            InputSource::Reference { data_ref } => {
74                if let Some(store) = data_store
75                    && let Ok(val) = store.get(data_ref)
76                {
77                    return val;
78                }
79                temp_store.get(data_ref).unwrap_or_else(|e| {
80                    tracing::warn!("Failed to resolve DataRef: {e}");
81                    Value::Empty
82                })
83            }
84        }
85    }
86}
87
/// A serialized filter: cloudpickle bytes to reconstruct on the worker.
///
/// Uses cloudpickle (like Spark/Dask/Ray) to serialize the full Python object
/// including bytecode, closures, and cross-module dependencies.
///
/// NOTE(review): unpickling executes arbitrary Python code. The worker-side
/// loader is not visible in this file — confirm these bytes are only ever
/// accepted from an authenticated, trusted coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedFilter {
    /// Node ID this filter is registered under.
    pub node_id: String,
    /// cloudpickle.dumps() bytes (base64-encoded for JSON transport).
    ///
    /// In memory this is raw bytes; the `base64_bytes` helper converts to and
    /// from a base64 string only at (de)serialization time.
    #[serde(with = "base64_bytes")]
    pub pickled_filter: Vec<u8>,
    /// Trained state (if fitted).
    pub state: Option<Value>,
    /// Pip requirements detected from the filter's imports (e.g. ["torch", "transformers"]).
    ///
    /// Defaults to an empty list when the field is missing from the message.
    #[serde(default)]
    pub requirements: Vec<String>,
    /// Whether the filter is trainable (has meaningful fit()) or stateless.
    ///
    /// Defaults to `false` when the field is missing from the message.
    #[serde(default)]
    pub trainable: bool,
}
108
109/// Serde helper: Vec<u8> ↔ base64 string for JSON-safe binary transport.
110mod base64_bytes {
111    use base64::engine::{Engine, general_purpose::STANDARD};
112    use serde::{Deserialize, Deserializer, Serialize, Serializer};
113
114    pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
115        STANDARD.encode(bytes).serialize(s)
116    }
117
118    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
119        let s = String::deserialize(d)?;
120        STANDARD.decode(s).map_err(serde::de::Error::custom)
121    }
122}
123
/// Execution mode: fit (training) or forward (inference).
///
/// `Default` yields [`ExecutionMode::Forward`], which is what
/// `SerializedPlan::mode` falls back to when the field is missing from a
/// message. No `#[serde(tag = ...)]` is set, so serde's default (externally
/// tagged) enum representation applies.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExecutionMode {
    /// Training: fit each filter, then forward to propagate outputs.
    Fit {
        /// Supervised labels (optional).
        y: Option<Value>,
        /// If set, the worker splits the input into batches internally.
        /// Model is loaded once, batches processed in a loop.
        ///
        /// Defaults to `None` when the field is missing from the message.
        #[serde(default)]
        batch_size: Option<usize>,
    },
    /// Inference: forward only (default).
    #[default]
    Forward,
}
141
/// A serialized plan ready for remote execution.
///
/// Sent to a worker in `CoordinatorToWorker::AssignPlan`, or boxed inside
/// `StreamMessage::StreamBegin` for streaming sessions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedPlan {
    /// Identifier of this plan execution.
    pub plan_id: PlanId,
    /// The compiled execution plan to run.
    pub plan: ExecutionPlan,
    /// Input data — inline for small values, DataRef for large ones.
    ///
    /// Optional: in a streaming session the input arrives as chunks instead.
    pub input: Option<InputSource>,
    /// Filter definitions for the worker to reconstruct.
    ///
    /// Defaults to an empty list when the field is missing from the message.
    #[serde(default)]
    pub filters: Vec<SerializedFilter>,
    /// Fit or Forward.
    ///
    /// Defaults to `ExecutionMode::Forward` when the field is missing.
    #[serde(default)]
    pub mode: ExecutionMode,
    /// Free-form JSON metadata (the tests pass `{"experiment": "test"}`).
    pub metadata: serde_json::Value,
}
157
/// Messages from Worker → Coordinator.
///
/// Serialized as an internally tagged enum: the variant name is written to a
/// `type` field next to the variant's own fields. Every variant carries the
/// `worker_id` of the sender.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerToCoordinator {
    /// Worker announces itself.
    Register {
        worker_id: WorkerId,
        /// Hardware/software description of the registering worker.
        capabilities: Capabilities,
    },

    /// Periodic health check.
    Heartbeat {
        worker_id: WorkerId,
        /// Load metrics reported with this heartbeat.
        load: LoadMetrics,
    },

    /// Execution event streamed back in real-time.
    Event {
        worker_id: WorkerId,
        /// Plan the event belongs to.
        plan_id: PlanId,
        event: Event,
    },

    /// Plan execution completed.
    PlanResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        /// Success or failure payload for the plan.
        result: PlanResult,
    },

    /// Python job progress update.
    JobProgress {
        worker_id: WorkerId,
        job_id: String,
        /// Name of the current phase (values are not defined in this file).
        phase: String,
        /// Current step; presumably counted against `total` — confirm.
        step: u32,
        /// Total number of steps (presumably for the current phase — confirm).
        total: u32,
        /// Free-form JSON metrics.
        metrics: serde_json::Value,
    },

    /// Python job result.
    JobResult {
        worker_id: WorkerId,
        job_id: String,
        /// Whether the job succeeded.
        success: bool,
        /// Free-form JSON metrics.
        metrics: serde_json::Value,
        /// Job output text (presumably captured stdout/stderr — confirm).
        output: String,
        /// Job duration in milliseconds.
        duration_ms: u64,
    },

    // ── Distributed training responses ──
    /// Response to GetState: trained filter states.
    StateResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        /// Filter states; presumably keyed by node ID, mirroring the
        /// `node_ids` of the `GetState` request — confirm.
        states: std::collections::HashMap<String, Value>,
    },

    /// Response to GetGradients: gradient data.
    GradientsResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        /// Gradient data; presumably keyed by node ID, mirroring the
        /// `node_ids` of the `GetGradients` request — confirm.
        gradients: std::collections::HashMap<String, Value>,
    },
}
223
/// A Python pipeline job: source files + requirements for isolated execution.
///
/// Sent to a worker in `CoordinatorToWorker::AssignPythonJob`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonPipelineJob {
    /// Job identifier; `JobProgress` and `JobResult` messages carry a `job_id` too.
    pub job_id: String,
    /// Identifier of the pipeline this job belongs to.
    pub pipeline_id: String,
    /// Identifier of the owning investigation (semantics not defined in this file).
    pub investigation_id: String,
    /// Source files: path → content
    pub files: Vec<PipelineFile>,
    /// pip requirements (content of requirements.txt)
    pub requirements: String,
    /// Entry point: which file/function to execute
    ///
    /// NOTE(review): the exact syntax (e.g. "file:function") is not defined here.
    pub entry_point: String,
    /// Input data (JSON-serialized)
    pub input_data: Option<serde_json::Value>,
    /// Extra parameters
    pub params: serde_json::Value,
}
241
/// A source file in a pipeline job.
///
/// Listed in `PythonPipelineJob::files`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFile {
    /// File path.
    ///
    /// NOTE(review): presumably relative to the job's working directory; the
    /// code that writes these files is not visible here — confirm it rejects
    /// absolute paths and `..` components.
    pub path: String,
    /// Full text content of the file.
    pub content: String,
}
248
/// Messages from Coordinator → Worker.
///
/// Serialized as an internally tagged enum: the variant name is written to a
/// `type` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CoordinatorToWorker {
    /// Accept worker registration.
    Registered { worker_id: WorkerId },

    /// Assign a native Soma plan for execution.
    AssignPlan { plan: SerializedPlan },

    /// Assign a Python pipeline job (with environment isolation).
    AssignPythonJob { job: PythonPipelineJob },

    /// Cancel a running plan/job.
    CancelPlan { plan_id: PlanId },

    /// Request current status.
    StatusRequest,

    /// Ping for keepalive.
    Ping,

    /// Graceful shutdown: worker should finish running plans and exit.
    Shutdown {
        /// Human-readable reason for the shutdown.
        reason: String,
    },

    // ── Distributed training messages ──
    /// Request trained states from specific filters.
    ///
    /// Answered with `WorkerToCoordinator::StateResult`.
    GetState {
        plan_id: PlanId,
        /// Node IDs of the filters whose state is requested.
        node_ids: Vec<String>,
    },

    /// Load states into filters (e.g. after FedAvg aggregation).
    SetState {
        plan_id: PlanId,
        /// States to load; presumably keyed by node ID — confirm.
        states: std::collections::HashMap<String, Value>,
    },

    /// Request gradients from filters (for AllReduce in DataParallel).
    ///
    /// Answered with `WorkerToCoordinator::GradientsResult`.
    GetGradients {
        plan_id: PlanId,
        /// Node IDs of the filters whose gradients are requested.
        node_ids: Vec<String>,
    },

    /// Apply aggregated gradients (after AllReduce).
    ApplyGradients {
        plan_id: PlanId,
        /// Gradients to apply; presumably keyed by node ID — confirm.
        gradients: std::collections::HashMap<String, Value>,
    },
}
299
/// How output is delivered in PlanResult.
///
/// Serialized as an internally tagged enum: the variant name is written to a
/// `delivery` field next to the variant's own fields. Marked
/// `#[non_exhaustive]`, so code in other crates must include a wildcard arm
/// when matching.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "delivery")]
#[non_exhaustive]
pub enum OutputDelivery {
    /// Small output — embedded directly in the WS message.
    Inline { value: Value },
    /// Large output — stored on worker, download via HTTP GET /download?ref=...
    /// where `ref` is the JSON-serialized `data_ref` (see `OutputDelivery::resolve`).
    Reference {
        data_ref: somatize_core::store::DataRef,
    },
}
312
313impl OutputDelivery {
314    /// Resolve the output to a concrete Value.
315    /// For Reference: downloads via HTTP from the worker.
316    pub fn resolve(&self, addr: &str, token: &Option<String>) -> Value {
317        match self {
318            OutputDelivery::Inline { value } => value.clone(),
319            OutputDelivery::Reference { data_ref } => {
320                // HTTP download in a dedicated thread (avoids tokio nesting)
321                let http_addr = addr
322                    .replace("ws://", "http://")
323                    .replace("wss://", "https://");
324                let url = format!("{http_addr}/download");
325                let ref_json = serde_json::to_string(data_ref).unwrap_or_default();
326                let token = token.clone();
327
328                std::thread::spawn(move || {
329                    let client = reqwest::blocking::Client::new();
330                    let mut req = client.get(&url).query(&[("ref", &ref_json)]);
331                    if let Some(t) = &token {
332                        req = req.query(&[("token", t.as_str())]);
333                    }
334                    let resp = req.send().ok()?;
335                    let bytes = resp.bytes().ok()?;
336                    serde_json::from_slice(&bytes).ok()
337                })
338                .join()
339                .ok()
340                .flatten()
341                .unwrap_or(Value::Empty)
342            }
343        }
344    }
345}
346
/// Result of a plan execution.
///
/// Serialized as an internally tagged enum: the variant name is written to a
/// `status` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum PlanResult {
    /// The plan ran to completion.
    Success {
        /// The plan's output, inline or by reference.
        output: OutputDelivery,
        /// Execution duration in milliseconds.
        duration_ms: u64,
        /// Trained states returned after Fit mode (node_id → state).
        /// Empty for Forward mode.
        ///
        /// Defaults to an empty map when the field is missing from the message.
        #[serde(default)]
        states: std::collections::HashMap<String, Value>,
    },
    /// The plan failed.
    Failed {
        /// Error description (the tests use "OOM").
        error: String,
        /// Duration in milliseconds (presumably time spent before the failure).
        duration_ms: u64,
    },
}
364
/// Streaming protocol: chunked data transfer over WebSocket Binary frames.
///
/// Wire format: msgpack-encoded StreamMessage (efficient binary, no JSON overhead).
/// Client sends StreamBegin + N × ChunkData + StreamEnd.
/// Worker responds with ChunkResult per chunk + StreamComplete at the end.
///
/// The enum is internally tagged: the variant name is written to a `type`
/// field next to the variant's own fields. Every variant carries the
/// `stream_id` of the session it belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum StreamMessage {
    /// Begin a streaming session.
    StreamBegin {
        stream_id: String,
        plan_id: PlanId,
        /// Number of chunks (None if unknown ahead of time).
        total_chunks: Option<usize>,
        /// The plan to execute — input comes via chunks, not inline.
        plan: Box<SerializedPlan>,
    },
    /// A single chunk of input data.
    ChunkData {
        stream_id: String,
        /// Index of this chunk within the stream (presumably zero-based — confirm).
        chunk_index: usize,
        value: Value,
    },
    /// All chunks have been sent.
    StreamEnd { stream_id: String },
    /// Result for a processed chunk (streamed back to client).
    ChunkResult {
        stream_id: String,
        /// Index of the input chunk this result is for (presumably — confirm).
        chunk_index: usize,
        value: Value,
    },
    /// Final result after all chunks processed.
    StreamComplete {
        stream_id: String,
        result: PlanResult,
    },
}
403
/// JSON round-trip tests for the wire-protocol message types.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// `Capabilities` survives a JSON round trip with its fields intact.
    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    /// `Register` carries its variant name in the JSON and round-trips.
    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Register"));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Register { worker_id, .. } = deserialized {
            assert_eq!(worker_id, "worker_01");
        } else {
            panic!("wrong variant");
        }
    }

    /// `AssignPlan` with an inline input round-trips to the same variant.
    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                mode: ExecutionMode::default(),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            CoordinatorToWorker::AssignPlan { .. }
        ));
    }

    /// Both `PlanResult` variants round-trip to the same variant.
    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: OutputDelivery::Inline {
                value: Value::tensor(vec![0.95], vec![1]),
            },
            duration_ms: 1234,
            states: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Failed { .. }));
    }

    /// An `Event` message wrapping `RunStarted` round-trips to the same variant.
    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    /// `Heartbeat` round-trips with its load metrics intact.
    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        // Fail loudly on the wrong variant; a bare `if let` would let the test
        // pass without checking anything.
        let WorkerToCoordinator::Heartbeat { worker_id, load } = deserialized else {
            panic!("wrong variant");
        };
        assert_eq!(worker_id, "w1");
        assert!((load.cpu_usage - 0.45).abs() < 1e-6);
        assert_eq!(load.gpu_usage.len(), 1);
        assert_eq!(load.active_plans, 2);
        assert_eq!(load.queue_depth, 5);
    }

    /// `pickled_filter` travels as a base64 string, and the `#[serde(default)]`
    /// fields may be omitted by the sender.
    #[test]
    fn serialized_filter_base64_serde() {
        let filter = SerializedFilter {
            node_id: "n1".into(),
            pickled_filter: vec![0x00, 0x01, 0x02, 0xff],
            state: None,
            requirements: vec!["torch".into()],
            trainable: true,
        };
        let json = serde_json::to_string(&filter).unwrap();
        assert!(json.contains("\"AAEC/w==\""));
        let deserialized: SerializedFilter = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.pickled_filter, vec![0x00, 0x01, 0x02, 0xff]);
        assert_eq!(deserialized.requirements, vec!["torch"]);
        assert!(deserialized.trainable);

        let minimal: SerializedFilter =
            serde_json::from_str(r#"{"node_id":"n1","pickled_filter":"AAEC/w==","state":null}"#)
                .unwrap();
        assert_eq!(minimal.pickled_filter, vec![0x00, 0x01, 0x02, 0xff]);
        assert!(minimal.requirements.is_empty());
        assert!(!minimal.trainable);
    }
}