Skip to main content

somatize_worker/
protocol.rs

1//! Wire protocol for coordinator ↔ worker communication.
2//!
3//! Defines message types for plan assignment, results, heartbeats,
4//! Python job management, and worker capabilities.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::DataRef;
11use somatize_core::value::Value;
12
/// Unique worker identifier.
///
/// A plain `String` alias: the type itself enforces neither a format nor
/// uniqueness. It is carried as the `worker_id` field of the protocol
/// messages defined in this module.
pub type WorkerId = String;
15
/// Unique plan execution identifier.
///
/// A plain `String` alias: the type itself enforces neither a format nor
/// uniqueness. It is carried as the `plan_id` field of plan-related
/// messages and of [`SerializedPlan`].
pub type PlanId = String;
18
/// Hardware and software capabilities of a worker.
///
/// Sent by a worker as the payload of [`WorkerToCoordinator::Register`].
/// Serializable via serde so it can travel inside protocol messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
    /// Number of CPU cores.
    pub cpu_cores: usize,
    /// Total RAM in bytes.
    pub ram_bytes: u64,
    /// GPU information, one [`GpuInfo`] entry per reported GPU; empty when
    /// the worker reports none.
    pub gpus: Vec<GpuInfo>,
    /// Available Python environments, given as strings (the tests use names
    /// such as "py310").
    pub python_envs: Vec<String>,
    /// User-defined tags for routing (e.g. "gpu", "training", "inference").
    pub tags: Vec<String>,
}
33
/// GPU hardware info.
///
/// Listed per device in [`Capabilities::gpus`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuInfo {
    /// Device name (the tests use "A100").
    pub name: String,
    /// GPU memory size in bytes.
    // NOTE(review): presumably total device memory rather than free memory — confirm.
    pub memory_bytes: u64,
}
40
/// Current load metrics reported by a worker.
///
/// Carried as the `load` field of [`WorkerToCoordinator::Heartbeat`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadMetrics {
    /// CPU usage.
    // NOTE(review): the scale is not defined here; the tests use 0.45, so this is
    // presumably a 0.0–1.0 fraction — confirm.
    pub cpu_usage: f32,
    /// Memory usage.
    // NOTE(review): presumably the same scale as `cpu_usage` — confirm.
    pub memory_usage: f32,
    /// GPU usage values.
    // NOTE(review): presumably one entry per GPU, in the same order as
    // `Capabilities::gpus` — confirm.
    pub gpu_usage: Vec<f32>,
    /// Number of active plans.
    pub active_plans: usize,
    /// Depth of the worker's queue.
    pub queue_depth: usize,
    /// UTC timestamp attached to these metrics.
    pub timestamp: DateTime<Utc>,
}
51
/// How input data is provided to a worker.
///
/// Serialized in serde's internally tagged form: the variant name is stored
/// in a `"source"` field next to the variant's own fields.
///
/// Marked `#[non_exhaustive]`, so code in other crates that matches on this
/// enum needs a wildcard arm; new variants may be added later.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "source")]
#[non_exhaustive]
pub enum InputSource {
    /// Data embedded directly in the message (small payloads).
    Inline { value: Value },
    /// Data referenced in a remote store (large payloads).
    Reference { data_ref: DataRef },
}
62
/// A serialized filter: cloudpickle bytes to reconstruct on the worker.
///
/// Uses cloudpickle (like Spark/Dask/Ray) to serialize the full Python object
/// including bytecode, closures, and cross-module dependencies.
///
/// Shipped to workers in [`SerializedPlan::filters`].
// NOTE(review): unpickling runs arbitrary Python code, so these payloads should
// only ever be accepted from a trusted coordinator — confirm the transport is
// authenticated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedFilter {
    /// Node ID this filter is registered under.
    pub node_id: String,
    /// cloudpickle.dumps() bytes (base64-encoded for JSON transport).
    ///
    /// In memory these are the raw bytes; the `base64_bytes` helper converts
    /// them to and from a base64 string only at (de)serialization time.
    #[serde(with = "base64_bytes")]
    pub pickled_filter: Vec<u8>,
    /// Trained state (if fitted); `None` otherwise.
    pub state: Option<Value>,
}
77
78/// Serde helper: Vec<u8> ↔ base64 string for JSON-safe binary transport.
79mod base64_bytes {
80    use base64::engine::{Engine, general_purpose::STANDARD};
81    use serde::{Deserialize, Deserializer, Serialize, Serializer};
82
83    pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
84        STANDARD.encode(bytes).serialize(s)
85    }
86
87    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
88        let s = String::deserialize(d)?;
89        STANDARD.decode(s).map_err(serde::de::Error::custom)
90    }
91}
92
/// A serialized plan ready for remote execution.
///
/// Carried as the payload of [`CoordinatorToWorker::AssignPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedPlan {
    /// Identifier of this plan execution.
    pub plan_id: PlanId,
    /// The compiled execution plan to run.
    pub plan: ExecutionPlan,
    /// Input data — inline for small values, DataRef for large ones.
    /// `None` when no input is supplied.
    pub input: Option<InputSource>,
    /// Filter definitions for the worker to reconstruct.
    ///
    /// `#[serde(default)]` makes this an empty list when the field is absent
    /// from the incoming message.
    #[serde(default)]
    pub filters: Vec<SerializedFilter>,
    /// Free-form JSON metadata attached to the plan (the tests pass
    /// `{"experiment": "test"}`).
    pub metadata: serde_json::Value,
}
105
/// Messages from Worker → Coordinator.
///
/// Serialized in serde's internally tagged form: the variant name is stored
/// in a `"type"` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerToCoordinator {
    /// Worker announces itself.
    Register {
        /// Identifier the worker announces itself under.
        worker_id: WorkerId,
        /// Hardware/software capabilities of the worker.
        capabilities: Capabilities,
    },

    /// Periodic health check.
    Heartbeat {
        /// Worker sending the heartbeat.
        worker_id: WorkerId,
        /// Load metrics reported with this heartbeat.
        load: LoadMetrics,
    },

    /// Execution event streamed back in real-time.
    Event {
        /// Worker that produced the event.
        worker_id: WorkerId,
        /// Plan the event belongs to.
        plan_id: PlanId,
        /// The execution event itself.
        event: Event,
    },

    /// Plan execution completed.
    PlanResult {
        /// Worker that executed the plan.
        worker_id: WorkerId,
        /// Plan that finished.
        plan_id: PlanId,
        /// Outcome of the execution (success or failure).
        result: PlanResult,
    },

    /// Python job progress update.
    JobProgress {
        /// Worker running the job.
        worker_id: WorkerId,
        /// Job this update refers to.
        job_id: String,
        /// Name of the current phase (free-form string).
        phase: String,
        /// Current step.
        // NOTE(review): presumably counted within `phase`, out of `total` — confirm.
        step: u32,
        /// Total number of steps.
        total: u32,
        /// Free-form JSON metrics attached to the update.
        metrics: serde_json::Value,
    },

    /// Python job result.
    JobResult {
        /// Worker that ran the job.
        worker_id: WorkerId,
        /// Job this result refers to.
        job_id: String,
        /// Whether the job succeeded.
        success: bool,
        /// Free-form JSON metrics produced by the job.
        metrics: serde_json::Value,
        /// Textual output of the job.
        // NOTE(review): presumably captured stdout/stderr — confirm.
        output: String,
        /// Job duration in milliseconds.
        duration_ms: u64,
    },
}
156
/// A Python pipeline job: source files + requirements for isolated execution.
///
/// Carried as the payload of [`CoordinatorToWorker::AssignPythonJob`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonPipelineJob {
    /// Identifier of this job; the `JobProgress` and `JobResult` messages
    /// carry a `job_id` field as well.
    pub job_id: String,
    /// Identifier of the pipeline this job belongs to.
    pub pipeline_id: String,
    /// Identifier of the investigation this job belongs to.
    // NOTE(review): "investigation" is not defined in this file — confirm meaning.
    pub investigation_id: String,
    /// Source files: path → content
    pub files: Vec<PipelineFile>,
    /// pip requirements (content of requirements.txt)
    pub requirements: String,
    /// Entry point: which file/function to execute
    pub entry_point: String,
    /// Input data (JSON-serialized); `None` when the job has no input.
    pub input_data: Option<serde_json::Value>,
    /// Extra parameters
    pub params: serde_json::Value,
}
174
/// A source file in a pipeline job.
///
/// Listed in [`PythonPipelineJob::files`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFile {
    /// Path of the file.
    // NOTE(review): presumably relative to the job's working directory — confirm.
    pub path: String,
    /// File content, held as text.
    pub content: String,
}
181
/// Messages from Coordinator → Worker.
///
/// Serialized in serde's internally tagged form: the variant name is stored
/// in a `"type"` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CoordinatorToWorker {
    /// Accept worker registration.
    Registered { worker_id: WorkerId },

    /// Assign a native Soma plan for execution.
    AssignPlan { plan: SerializedPlan },

    /// Assign a Python pipeline job (with environment isolation).
    AssignPythonJob { job: PythonPipelineJob },

    /// Cancel a running plan/job.
    // NOTE(review): only a `plan_id` is carried; how a Python job's `job_id`
    // maps onto it is not visible in this file — confirm.
    CancelPlan { plan_id: PlanId },

    /// Request current status.
    StatusRequest,

    /// Ping for keepalive.
    Ping,
}
204
/// Result of a plan execution.
///
/// Serialized in serde's internally tagged form: the variant name is stored
/// in a `"status"` field next to the variant's own fields. Carried as the
/// `result` field of [`WorkerToCoordinator::PlanResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum PlanResult {
    /// The plan ran to completion: its output value and the duration in
    /// milliseconds.
    Success { output: Value, duration_ms: u64 },
    /// The plan failed: an error message and the duration in milliseconds.
    Failed { error: String, duration_ms: u64 },
}
212
/// JSON round-trip tests for the wire protocol types.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// `Capabilities` survives a JSON round trip with its fields intact.
    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    /// A `Register` message carries its variant name in the JSON and
    /// deserializes back to the same variant.
    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Register"));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Register { worker_id, .. } = deserialized {
            assert_eq!(worker_id, "worker_01");
        } else {
            panic!("wrong variant");
        }
    }

    /// An `AssignPlan` message with inline input round-trips to the same variant.
    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            CoordinatorToWorker::AssignPlan { .. }
        ));
    }

    /// Both `PlanResult` variants keep their variant through a round trip.
    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: Value::tensor(vec![0.95], vec![1]),
            duration_ms: 1234,
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Failed { .. }));
    }

    /// An `Event` message wrapping a core event round-trips to the same variant.
    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    /// A `Heartbeat` message round-trips with its load metrics intact.
    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Heartbeat { load, .. } = deserialized {
            assert!(load.cpu_usage > 0.0);
            assert_eq!(load.active_plans, 2);
        } else {
            // Without this branch a wrong variant would let the test pass
            // without running a single assertion.
            panic!("wrong variant");
        }
    }

    /// `SerializedFilter::pickled_filter` travels as a base64 string and
    /// decodes back to the original bytes.
    #[test]
    fn serialized_filter_base64_roundtrip() {
        let filter = SerializedFilter {
            node_id: "n1".into(),
            pickled_filter: vec![1, 2, 3],
            state: None,
        };
        let json = serde_json::to_string(&filter).unwrap();
        // The bytes 0x01 0x02 0x03 encode to "AQID" in standard base64.
        assert!(json.contains("\"AQID\""));
        let deserialized: SerializedFilter = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.pickled_filter, vec![1, 2, 3]);
        assert!(deserialized.state.is_none());
    }
}