// Source: somatize_worker/protocol.rs

//! Wire protocol for coordinator ↔ worker communication.
//!
//! Defines message types for plan assignment, results, heartbeats,
//! Python job management, and worker capabilities.

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use somatize_compiler::ExecutionPlan;
use somatize_core::event::Event;
use somatize_core::store::DataRef;
use somatize_core::value::Value;

/// Unique worker identifier.
///
/// A plain alias for [`String`]: it documents intent at use sites but gives no
/// type-level separation from other string identifiers.
pub type WorkerId = String;

/// Unique plan execution identifier.
///
/// A plain alias for [`String`]: it documents intent at use sites but gives no
/// type-level separation from other string identifiers.
pub type PlanId = String;

19/// Hardware and software capabilities of a worker.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Capabilities {
22    /// Number of CPU cores.
23    pub cpu_cores: usize,
24    /// Total RAM in bytes.
25    pub ram_bytes: u64,
26    /// GPU information.
27    pub gpus: Vec<GpuInfo>,
28    /// Available Python environments.
29    pub python_envs: Vec<String>,
30    /// User-defined tags for routing (e.g. "gpu", "training", "inference").
31    pub tags: Vec<String>,
32}
33
34/// GPU hardware info.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct GpuInfo {
37    pub name: String,
38    pub memory_bytes: u64,
39}
40
41/// Current load metrics reported by a worker.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct LoadMetrics {
44    pub cpu_usage: f32,
45    pub memory_usage: f32,
46    pub gpu_usage: Vec<f32>,
47    pub active_plans: usize,
48    pub queue_depth: usize,
49    pub timestamp: DateTime<Utc>,
50}
51
52/// How input data is provided to a worker.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54#[serde(tag = "source")]
55#[non_exhaustive]
56pub enum InputSource {
57    /// Data embedded directly in the message (small payloads).
58    Inline { value: Value },
59    /// Data referenced in a remote store (large payloads).
60    Reference { data_ref: DataRef },
61}
62
63/// A serialized filter: cloudpickle bytes to reconstruct on the worker.
64///
65/// Uses cloudpickle (like Spark/Dask/Ray) to serialize the full Python object
66/// including bytecode, closures, and cross-module dependencies.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct SerializedFilter {
69    /// Node ID this filter is registered under.
70    pub node_id: String,
71    /// cloudpickle.dumps() bytes (base64-encoded for JSON transport).
72    #[serde(with = "base64_bytes")]
73    pub pickled_filter: Vec<u8>,
74    /// Trained state (if fitted).
75    pub state: Option<Value>,
76    /// Pip requirements detected from the filter's imports (e.g. ["torch", "transformers"]).
77    #[serde(default)]
78    pub requirements: Vec<String>,
79    /// Whether the filter is trainable (has meaningful fit()) or stateless.
80    #[serde(default)]
81    pub trainable: bool,
82}
83
84/// Serde helper: Vec<u8> ↔ base64 string for JSON-safe binary transport.
85mod base64_bytes {
86    use base64::engine::{Engine, general_purpose::STANDARD};
87    use serde::{Deserialize, Deserializer, Serialize, Serializer};
88
89    pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
90        STANDARD.encode(bytes).serialize(s)
91    }
92
93    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
94        let s = String::deserialize(d)?;
95        STANDARD.decode(s).map_err(serde::de::Error::custom)
96    }
97}
98
99/// Execution mode: fit (training) or forward (inference).
100#[derive(Debug, Clone, Serialize, Deserialize, Default)]
101#[non_exhaustive]
102pub enum ExecutionMode {
103    /// Training: fit each filter, then forward to propagate outputs.
104    Fit {
105        /// Supervised labels (optional).
106        y: Option<Value>,
107    },
108    /// Inference: forward only (default).
109    #[default]
110    Forward,
111}
112
113/// A serialized plan ready for remote execution.
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct SerializedPlan {
116    pub plan_id: PlanId,
117    pub plan: ExecutionPlan,
118    /// Input data — inline for small values, DataRef for large ones.
119    pub input: Option<InputSource>,
120    /// Filter definitions for the worker to reconstruct.
121    #[serde(default)]
122    pub filters: Vec<SerializedFilter>,
123    /// Fit or Forward.
124    #[serde(default)]
125    pub mode: ExecutionMode,
126    pub metadata: serde_json::Value,
127}
128
129/// Messages from Worker → Coordinator.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131#[serde(tag = "type")]
132pub enum WorkerToCoordinator {
133    /// Worker announces itself.
134    Register {
135        worker_id: WorkerId,
136        capabilities: Capabilities,
137    },
138
139    /// Periodic health check.
140    Heartbeat {
141        worker_id: WorkerId,
142        load: LoadMetrics,
143    },
144
145    /// Execution event streamed back in real-time.
146    Event {
147        worker_id: WorkerId,
148        plan_id: PlanId,
149        event: Event,
150    },
151
152    /// Plan execution completed.
153    PlanResult {
154        worker_id: WorkerId,
155        plan_id: PlanId,
156        result: PlanResult,
157    },
158
159    /// Python job progress update.
160    JobProgress {
161        worker_id: WorkerId,
162        job_id: String,
163        phase: String,
164        step: u32,
165        total: u32,
166        metrics: serde_json::Value,
167    },
168
169    /// Python job result.
170    JobResult {
171        worker_id: WorkerId,
172        job_id: String,
173        success: bool,
174        metrics: serde_json::Value,
175        output: String,
176        duration_ms: u64,
177    },
178}
179
180/// A Python pipeline job: source files + requirements for isolated execution.
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct PythonPipelineJob {
183    pub job_id: String,
184    pub pipeline_id: String,
185    pub investigation_id: String,
186    /// Source files: path → content
187    pub files: Vec<PipelineFile>,
188    /// pip requirements (content of requirements.txt)
189    pub requirements: String,
190    /// Entry point: which file/function to execute
191    pub entry_point: String,
192    /// Input data (JSON-serialized)
193    pub input_data: Option<serde_json::Value>,
194    /// Extra parameters
195    pub params: serde_json::Value,
196}
197
198/// A source file in a pipeline job.
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct PipelineFile {
201    pub path: String,
202    pub content: String,
203}
204
205/// Messages from Coordinator → Worker.
206#[derive(Debug, Clone, Serialize, Deserialize)]
207#[serde(tag = "type")]
208pub enum CoordinatorToWorker {
209    /// Accept worker registration.
210    Registered { worker_id: WorkerId },
211
212    /// Assign a native Soma plan for execution.
213    AssignPlan { plan: SerializedPlan },
214
215    /// Assign a Python pipeline job (with environment isolation).
216    AssignPythonJob { job: PythonPipelineJob },
217
218    /// Cancel a running plan/job.
219    CancelPlan { plan_id: PlanId },
220
221    /// Request current status.
222    StatusRequest,
223
224    /// Ping for keepalive.
225    Ping,
226
227    /// Graceful shutdown: worker should finish running plans and exit.
228    Shutdown { reason: String },
229}
230
231/// Result of a plan execution.
232#[derive(Debug, Clone, Serialize, Deserialize)]
233#[serde(tag = "status")]
234pub enum PlanResult {
235    Success {
236        output: Value,
237        duration_ms: u64,
238        /// Trained states returned after Fit mode (node_id → state).
239        /// Empty for Forward mode.
240        #[serde(default)]
241        states: std::collections::HashMap<String, Value>,
242    },
243    Failed {
244        error: String,
245        duration_ms: u64,
246    },
247}
248
249/// Streaming protocol: chunked data transfer over WebSocket Binary frames.
250///
251/// Wire format: msgpack-encoded StreamMessage (efficient binary, no JSON overhead).
252/// Client sends StreamBegin + N × ChunkData + StreamEnd.
253/// Worker responds with ChunkResult per chunk + StreamComplete at the end.
254#[derive(Debug, Clone, Serialize, Deserialize)]
255#[serde(tag = "type")]
256#[non_exhaustive]
257pub enum StreamMessage {
258    /// Begin a streaming session.
259    StreamBegin {
260        stream_id: String,
261        plan_id: PlanId,
262        /// Number of chunks (None if unknown ahead of time).
263        total_chunks: Option<usize>,
264        /// The plan to execute — input comes via chunks, not inline.
265        plan: Box<SerializedPlan>,
266    },
267    /// A single chunk of input data.
268    ChunkData {
269        stream_id: String,
270        chunk_index: usize,
271        value: Value,
272    },
273    /// All chunks have been sent.
274    StreamEnd { stream_id: String },
275    /// Result for a processed chunk (streamed back to client).
276    ChunkResult {
277        stream_id: String,
278        chunk_index: usize,
279        value: Value,
280    },
281    /// Final result after all chunks processed.
282    StreamComplete {
283        stream_id: String,
284        result: PlanResult,
285    },
286}
287
/// JSON round-trip tests for the protocol types.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// `Capabilities` keeps its fields across a JSON round trip.
    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.ram_bytes, 32 * 1024 * 1024 * 1024);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.gpus[0].name, "A100");
        assert_eq!(deserialized.python_envs, vec!["py310", "py311"]);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    /// `Register` carries its variant name on the wire and decodes back to the
    /// same variant.
    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Register"));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Register { worker_id, .. } = deserialized {
            assert_eq!(worker_id, "worker_01");
        } else {
            panic!("wrong variant");
        }
    }

    /// `AssignPlan` with an inline input survives a JSON round trip.
    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                mode: ExecutionMode::default(),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            CoordinatorToWorker::AssignPlan { .. }
        ));
    }

    /// The default execution mode is inference.
    #[test]
    fn execution_mode_defaults_to_forward() {
        assert!(matches!(ExecutionMode::default(), ExecutionMode::Forward));
    }

    /// Both `PlanResult` variants round-trip and are tagged under `"status"`.
    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: Value::tensor(vec![0.95], vec![1]),
            duration_ms: 1234,
            states: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&success).unwrap();
        assert!(json.contains(r#""status":"Success""#));
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        assert!(json.contains(r#""status":"Failed""#));
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Failed { .. }));
    }

    /// An `Event` message wrapping a core event round-trips.
    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    /// A heartbeat round-trips and keeps its load metrics.
    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Heartbeat { load, .. } = deserialized {
            assert!(load.cpu_usage > 0.0);
            assert_eq!(load.active_plans, 2);
            assert_eq!(load.queue_depth, 5);
        } else {
            // Without this arm a wrong variant would make the test pass vacuously.
            panic!("wrong variant");
        }
    }

    /// The pickled payload travels as a base64 string and decodes back to the
    /// same bytes.
    #[test]
    fn serialized_filter_base64_roundtrip() {
        let filter = SerializedFilter {
            node_id: "n1".into(),
            pickled_filter: vec![0, 1, 2, 255],
            state: None,
            requirements: vec![],
            trainable: false,
        };
        let json = serde_json::to_value(&filter).unwrap();
        assert_eq!(json["pickled_filter"], "AAEC/w==");
        let decoded: SerializedFilter = serde_json::from_value(json).unwrap();
        assert_eq!(decoded.pickled_filter, vec![0, 1, 2, 255]);
        assert_eq!(decoded.node_id, "n1");
    }
}