Skip to main content

somatize_worker/
protocol.rs

//! Wire protocol for coordinator ↔ worker communication.
//!
//! Defines message types for worker registration and capabilities, plan
//! assignment and results, heartbeats, Python job management, and chunked
//! streaming of plan input and output.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::DataRef;
11use somatize_core::value::Value;
12
/// Unique worker identifier.
///
/// This is a plain alias for `String`, not a newtype: the compiler will not
/// stop a [`PlanId`] (or any other string) from being passed where a
/// `WorkerId` is expected.
pub type WorkerId = String;
15
/// Unique plan execution identifier.
///
/// This is a plain alias for `String`, not a newtype: the compiler will not
/// stop a [`WorkerId`] (or any other string) from being passed where a
/// `PlanId` is expected.
pub type PlanId = String;
18
/// Hardware and software capabilities of a worker.
///
/// Sent as the payload of [`WorkerToCoordinator::Register`]. None of the
/// fields carry `#[serde(default)]`, so all of them must be present when
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
    /// Number of CPU cores.
    pub cpu_cores: usize,
    /// Total RAM in bytes.
    pub ram_bytes: u64,
    /// GPU information; empty when the worker reports no GPUs.
    pub gpus: Vec<GpuInfo>,
    /// Available Python environments, as free-form strings (e.g. "py310").
    pub python_envs: Vec<String>,
    /// User-defined tags for routing (e.g. "gpu", "training", "inference").
    pub tags: Vec<String>,
}
33
/// GPU hardware info.
///
/// Listed in [`Capabilities::gpus`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuInfo {
    /// Device name (e.g. "A100").
    pub name: String,
    /// GPU memory in bytes.
    /// NOTE(review): presumably total device memory rather than free memory — confirm.
    pub memory_bytes: u64,
}
40
/// Current load metrics reported by a worker.
///
/// Sent as the payload of [`WorkerToCoordinator::Heartbeat`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadMetrics {
    /// CPU utilisation.
    /// NOTE(review): presumably a fraction in `0.0..=1.0` rather than a
    /// percentage (the tests use values like 0.45) — confirm with the producer.
    pub cpu_usage: f32,
    /// Memory utilisation; presumably on the same scale as `cpu_usage` — confirm.
    pub memory_usage: f32,
    /// GPU utilisation values; presumably one entry per GPU — confirm.
    pub gpu_usage: Vec<f32>,
    /// Number of active plans.
    pub active_plans: usize,
    /// Depth of the worker's queue; presumably plans waiting to run — confirm.
    pub queue_depth: usize,
    /// UTC timestamp attached to this report.
    pub timestamp: DateTime<Utc>,
}
51
/// How input data is provided to a worker.
///
/// Serialized with an internal `"source"` tag holding the variant name, so an
/// inline input looks like `{"source": "Inline", "value": ...}` in JSON.
///
/// Marked `#[non_exhaustive]`: `match` expressions outside this crate need a
/// wildcard arm, which lets new sources be added without a breaking change.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "source")]
#[non_exhaustive]
pub enum InputSource {
    /// Data embedded directly in the message (small payloads).
    Inline { value: Value },
    /// Data referenced in a remote store (large payloads).
    Reference { data_ref: DataRef },
}
62
/// A serialized filter: cloudpickle bytes to reconstruct on the worker.
///
/// Uses cloudpickle (like Spark/Dask/Ray) to serialize the full Python object
/// including bytecode, closures, and cross-module dependencies.
///
/// NOTE(review): unpickling executes arbitrary code, so these bytes should only
/// be accepted from a trusted coordinator — confirm the transport is
/// authenticated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedFilter {
    /// Node ID this filter is registered under.
    pub node_id: String,
    /// cloudpickle.dumps() bytes (base64-encoded for JSON transport).
    ///
    /// Held as raw bytes in memory; the `base64_bytes` helper turns them into a
    /// base64 string on the wire. The helper is applied for every serde format,
    /// so binary formats carry the base64 text as well.
    #[serde(with = "base64_bytes")]
    pub pickled_filter: Vec<u8>,
    /// Trained state (if fitted).
    pub state: Option<Value>,
    /// Pip requirements detected from the filter's imports (e.g. ["torch", "transformers"]).
    ///
    /// Defaults to an empty list when the field is absent from the message.
    #[serde(default)]
    pub requirements: Vec<String>,
}
80
81/// Serde helper: Vec<u8> ↔ base64 string for JSON-safe binary transport.
82mod base64_bytes {
83    use base64::engine::{Engine, general_purpose::STANDARD};
84    use serde::{Deserialize, Deserializer, Serialize, Serializer};
85
86    pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
87        STANDARD.encode(bytes).serialize(s)
88    }
89
90    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
91        let s = String::deserialize(d)?;
92        STANDARD.decode(s).map_err(serde::de::Error::custom)
93    }
94}
95
/// Execution mode: fit (training) or forward (inference).
///
/// No `#[serde(tag = ...)]` is set, so this enum uses serde's default
/// externally tagged form: `"Forward"` for the unit variant and
/// `{"Fit": {"y": ...}}` for the struct variant. This differs from the
/// internally tagged message enums in this file.
///
/// Marked `#[non_exhaustive]`: `match` expressions outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExecutionMode {
    /// Training: fit each filter, then forward to propagate outputs.
    Fit {
        /// Supervised labels (optional).
        y: Option<Value>,
    },
    /// Inference: forward only (default).
    ///
    /// This is the value produced by `ExecutionMode::default()`, and therefore
    /// what a `#[serde(default)]` field of this type gets when it is omitted.
    #[default]
    Forward,
}
109
/// A serialized plan ready for remote execution.
///
/// Carried by [`CoordinatorToWorker::AssignPlan`] and by
/// [`StreamMessage::StreamBegin`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedPlan {
    /// Identifier of this plan execution.
    pub plan_id: PlanId,
    /// The compiled execution plan to run.
    pub plan: ExecutionPlan,
    /// Input data — inline for small values, DataRef for large ones.
    ///
    /// `None` when no input travels with the plan (for streaming sessions the
    /// input arrives as chunks instead).
    pub input: Option<InputSource>,
    /// Filter definitions for the worker to reconstruct.
    ///
    /// Defaults to an empty list when the field is absent from the message.
    #[serde(default)]
    pub filters: Vec<SerializedFilter>,
    /// Fit or Forward.
    ///
    /// Defaults to [`ExecutionMode::Forward`] when the field is absent.
    #[serde(default)]
    pub mode: ExecutionMode,
    /// Free-form JSON metadata. Required on the wire: there is no
    /// `#[serde(default)]` on this field.
    pub metadata: serde_json::Value,
}
125
/// Messages from Worker → Coordinator.
///
/// Serialized with an internal `"type"` tag holding the variant name
/// (e.g. `{"type": "Register", ...}` in JSON). Every variant carries the
/// sending worker's `worker_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerToCoordinator {
    /// Worker announces itself.
    Register {
        worker_id: WorkerId,
        /// Hardware/software description of the worker.
        capabilities: Capabilities,
    },

    /// Periodic health check.
    Heartbeat {
        worker_id: WorkerId,
        /// Load figures reported with this heartbeat.
        load: LoadMetrics,
    },

    /// Execution event streamed back in real-time.
    Event {
        worker_id: WorkerId,
        /// Plan the event belongs to.
        plan_id: PlanId,
        event: Event,
    },

    /// Plan execution completed.
    PlanResult {
        worker_id: WorkerId,
        /// Plan that finished.
        plan_id: PlanId,
        /// Outcome: success with output, or failure with an error message.
        result: PlanResult,
    },

    /// Python job progress update.
    JobProgress {
        worker_id: WorkerId,
        /// NOTE(review): presumably the `job_id` of the assigned
        /// `PythonPipelineJob` — confirm.
        job_id: String,
        /// Free-form name of the current phase.
        phase: String,
        /// Current step; presumably counted against `total` — confirm.
        step: u32,
        /// Total number of steps.
        total: u32,
        /// Free-form JSON metrics.
        metrics: serde_json::Value,
    },

    /// Python job result.
    JobResult {
        worker_id: WorkerId,
        /// NOTE(review): presumably the `job_id` of the assigned
        /// `PythonPipelineJob` — confirm.
        job_id: String,
        /// Whether the job succeeded.
        success: bool,
        /// Free-form JSON metrics.
        metrics: serde_json::Value,
        /// Textual output of the job.
        /// NOTE(review): presumably captured stdout/stderr — confirm.
        output: String,
        /// Job duration in milliseconds.
        duration_ms: u64,
    },
}
176
/// A Python pipeline job: source files + requirements for isolated execution.
///
/// Carried by [`CoordinatorToWorker::AssignPythonJob`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonPipelineJob {
    /// Identifier of this job.
    pub job_id: String,
    /// Identifier of the pipeline the job belongs to (opaque to this module).
    pub pipeline_id: String,
    /// Identifier of the investigation the job belongs to (opaque to this module).
    pub investigation_id: String,
    /// Source files: path → content
    pub files: Vec<PipelineFile>,
    /// pip requirements (content of requirements.txt)
    pub requirements: String,
    /// Entry point: which file/function to execute
    pub entry_point: String,
    /// Input data (JSON-serialized)
    pub input_data: Option<serde_json::Value>,
    /// Extra parameters
    pub params: serde_json::Value,
}
194
/// A source file in a pipeline job.
///
/// Listed in [`PythonPipelineJob::files`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFile {
    /// Path of the file.
    /// NOTE(review): presumably relative to the job's working directory — confirm.
    pub path: String,
    /// Full text content of the file.
    pub content: String,
}
201
/// Messages from Coordinator → Worker.
///
/// Serialized with an internal `"type"` tag holding the variant name
/// (e.g. `{"type": "Ping"}` in JSON).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CoordinatorToWorker {
    /// Accept worker registration.
    Registered { worker_id: WorkerId },

    /// Assign a native Soma plan for execution.
    AssignPlan { plan: SerializedPlan },

    /// Assign a Python pipeline job (with environment isolation).
    AssignPythonJob { job: PythonPipelineJob },

    /// Cancel a running plan/job.
    ///
    /// NOTE(review): only a `plan_id` is carried; how a Python job (which is
    /// identified by a `job_id`) is addressed is not visible here — confirm.
    CancelPlan { plan_id: PlanId },

    /// Request current status.
    StatusRequest,

    /// Ping for keepalive.
    Ping,

    /// Graceful shutdown: worker should finish running plans and exit.
    Shutdown { reason: String },
}
227
/// Result of a plan execution.
///
/// Serialized with an internal `"status"` tag holding the variant name
/// (`"Success"` or `"Failed"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum PlanResult {
    /// The plan ran to completion.
    Success {
        /// Output value of the plan.
        output: Value,
        /// Execution time in milliseconds.
        duration_ms: u64,
        /// Trained states returned after Fit mode (node_id → state).
        /// Empty for Forward mode.
        ///
        /// Defaults to an empty map when the field is absent from the message.
        #[serde(default)]
        states: std::collections::HashMap<String, Value>,
    },
    /// The plan failed.
    Failed {
        /// Error message.
        error: String,
        /// Time spent before the failure, in milliseconds.
        duration_ms: u64,
    },
}
245
/// Streaming protocol: chunked data transfer over WebSocket Binary frames.
///
/// Wire format: msgpack-encoded StreamMessage (efficient binary, no JSON overhead).
/// Client sends StreamBegin + N × ChunkData + StreamEnd.
/// Worker responds with ChunkResult per chunk + StreamComplete at the end.
///
/// Serialized with an internal `"type"` tag holding the variant name. Every
/// variant carries the `stream_id` of the session it belongs to.
///
/// Marked `#[non_exhaustive]`: `match` expressions outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum StreamMessage {
    /// Begin a streaming session.
    StreamBegin {
        stream_id: String,
        /// Plan execution this stream feeds.
        plan_id: PlanId,
        /// Number of chunks (None if unknown ahead of time).
        total_chunks: Option<usize>,
        /// The plan to execute — input comes via chunks, not inline.
        ///
        /// Boxed, presumably to keep this large payload from inflating the
        /// size of every `StreamMessage` value.
        plan: Box<SerializedPlan>,
    },
    /// A single chunk of input data.
    ChunkData {
        stream_id: String,
        /// Position of this chunk within the stream.
        /// NOTE(review): presumably zero-based — confirm.
        chunk_index: usize,
        /// Payload of the chunk.
        value: Value,
    },
    /// All chunks have been sent.
    StreamEnd { stream_id: String },
    /// Result for a processed chunk (streamed back to client).
    ChunkResult {
        stream_id: String,
        /// Index of the input chunk this result answers.
        chunk_index: usize,
        /// Result computed for that chunk.
        value: Value,
    },
    /// Final result after all chunks processed.
    StreamComplete {
        stream_id: String,
        /// Overall outcome of the streamed plan.
        result: PlanResult,
    },
}
284
/// JSON round-trip tests for the protocol types.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// `Capabilities` survives a JSON round trip with its fields intact.
    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    /// `Register` is tagged with its variant name and round-trips.
    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Register"));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Register { worker_id, .. } = deserialized {
            assert_eq!(worker_id, "worker_01");
        } else {
            panic!("wrong variant");
        }
    }

    /// `AssignPlan` with an inline input round-trips.
    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                mode: ExecutionMode::default(),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            CoordinatorToWorker::AssignPlan { .. }
        ));
    }

    /// Both `PlanResult` variants round-trip to the same variant.
    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: Value::tensor(vec![0.95], vec![1]),
            duration_ms: 1234,
            states: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Failed { .. }));
    }

    /// An `Event` message wrapping a core event round-trips.
    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    /// `Heartbeat` round-trips and keeps its load metrics.
    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Heartbeat { load, .. } = deserialized {
            assert!(load.cpu_usage > 0.0);
            assert_eq!(load.active_plans, 2);
        } else {
            // Without this branch the test would pass even if the message
            // decoded as a different variant.
            panic!("wrong variant");
        }
    }

    /// `pickled_filter` travels as a base64 string and decodes back to the
    /// same bytes.
    #[test]
    fn serialized_filter_base64_roundtrip() {
        let filter = SerializedFilter {
            node_id: "n1".into(),
            pickled_filter: vec![0x00, 0x01, 0x02, 0xff],
            state: None,
            requirements: vec!["torch".into()],
        };
        let json = serde_json::to_value(&filter).unwrap();
        assert_eq!(json["pickled_filter"], "AAEC/w==");

        let deserialized: SerializedFilter = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized.pickled_filter, vec![0x00, 0x01, 0x02, 0xff]);
        assert_eq!(deserialized.requirements, vec!["torch"]);
    }

    /// Invalid base64 is rejected, and an omitted `requirements` field
    /// defaults to empty.
    #[test]
    fn serialized_filter_decoding_edge_cases() {
        let invalid = r#"{"node_id":"n1","pickled_filter":"not base64!","state":null}"#;
        assert!(serde_json::from_str::<SerializedFilter>(invalid).is_err());

        let minimal = r#"{"node_id":"n1","pickled_filter":"AAEC/w==","state":null}"#;
        let deserialized: SerializedFilter = serde_json::from_str(minimal).unwrap();
        assert!(deserialized.requirements.is_empty());
    }
}