// somatize_worker/protocol.rs

//! Wire protocol for coordinator ↔ worker communication.
//!
//! Defines the message types for plan assignment, results, heartbeats,
//! Python job management, and worker capabilities.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::DataRef;
11use somatize_core::value::Value;
12
/// Identifier under which a worker appears in protocol messages.
///
/// This is a plain alias for `String`: it documents intent but adds no
/// type-level distinction from other strings.
pub type WorkerId = String;
15
/// Identifier of a single plan execution.
///
/// This is a plain alias for `String`: it documents intent but adds no
/// type-level distinction from other strings.
pub type PlanId = String;
18
19/// Hardware and software capabilities of a worker.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Capabilities {
22    /// Number of CPU cores.
23    pub cpu_cores: usize,
24    /// Total RAM in bytes.
25    pub ram_bytes: u64,
26    /// GPU information.
27    pub gpus: Vec<GpuInfo>,
28    /// Available Python environments.
29    pub python_envs: Vec<String>,
30    /// User-defined tags for routing (e.g. "gpu", "training", "inference").
31    pub tags: Vec<String>,
32}
33
34/// GPU hardware info.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct GpuInfo {
37    pub name: String,
38    pub memory_bytes: u64,
39}
40
41/// Current load metrics reported by a worker.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct LoadMetrics {
44    pub cpu_usage: f32,
45    pub memory_usage: f32,
46    pub gpu_usage: Vec<f32>,
47    pub active_plans: usize,
48    pub queue_depth: usize,
49    pub timestamp: DateTime<Utc>,
50}
51
52/// How input data is provided to a worker.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54#[serde(tag = "source")]
55#[non_exhaustive]
56pub enum InputSource {
57    /// Data embedded directly in the message (small payloads).
58    Inline { value: Value },
59    /// Data referenced in a remote store (large payloads).
60    Reference { data_ref: DataRef },
61}
62
63/// A serialized plan ready for remote execution.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct SerializedPlan {
66    pub plan_id: PlanId,
67    pub plan: ExecutionPlan,
68    /// Input data — inline for small values, DataRef for large ones.
69    pub input: Option<InputSource>,
70    pub metadata: serde_json::Value,
71}
72
73/// Messages from Worker → Coordinator.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75#[serde(tag = "type")]
76pub enum WorkerToCoordinator {
77    /// Worker announces itself.
78    Register {
79        worker_id: WorkerId,
80        capabilities: Capabilities,
81    },
82
83    /// Periodic health check.
84    Heartbeat {
85        worker_id: WorkerId,
86        load: LoadMetrics,
87    },
88
89    /// Execution event streamed back in real-time.
90    Event {
91        worker_id: WorkerId,
92        plan_id: PlanId,
93        event: Event,
94    },
95
96    /// Plan execution completed.
97    PlanResult {
98        worker_id: WorkerId,
99        plan_id: PlanId,
100        result: PlanResult,
101    },
102
103    /// Python job progress update.
104    JobProgress {
105        worker_id: WorkerId,
106        job_id: String,
107        phase: String,
108        step: u32,
109        total: u32,
110        metrics: serde_json::Value,
111    },
112
113    /// Python job result.
114    JobResult {
115        worker_id: WorkerId,
116        job_id: String,
117        success: bool,
118        metrics: serde_json::Value,
119        output: String,
120        duration_ms: u64,
121    },
122}
123
124/// A Python pipeline job: source files + requirements for isolated execution.
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct PythonPipelineJob {
127    pub job_id: String,
128    pub pipeline_id: String,
129    pub investigation_id: String,
130    /// Source files: path → content
131    pub files: Vec<PipelineFile>,
132    /// pip requirements (content of requirements.txt)
133    pub requirements: String,
134    /// Entry point: which file/function to execute
135    pub entry_point: String,
136    /// Input data (JSON-serialized)
137    pub input_data: Option<serde_json::Value>,
138    /// Extra parameters
139    pub params: serde_json::Value,
140}
141
142/// A source file in a pipeline job.
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct PipelineFile {
145    pub path: String,
146    pub content: String,
147}
148
149/// Messages from Coordinator → Worker.
150#[derive(Debug, Clone, Serialize, Deserialize)]
151#[serde(tag = "type")]
152pub enum CoordinatorToWorker {
153    /// Accept worker registration.
154    Registered { worker_id: WorkerId },
155
156    /// Assign a native Soma plan for execution.
157    AssignPlan { plan: SerializedPlan },
158
159    /// Assign a Python pipeline job (with environment isolation).
160    AssignPythonJob { job: PythonPipelineJob },
161
162    /// Cancel a running plan/job.
163    CancelPlan { plan_id: PlanId },
164
165    /// Request current status.
166    StatusRequest,
167
168    /// Ping for keepalive.
169    Ping,
170}
171
172/// Result of a plan execution.
173#[derive(Debug, Clone, Serialize, Deserialize)]
174#[serde(tag = "status")]
175pub enum PlanResult {
176    Success { output: Value, duration_ms: u64 },
177    Failed { error: String, duration_ms: u64 },
178}
179
/// Serde round-trip tests for the wire types.
///
/// Every test decodes what it encoded and asserts on the decoded payload.
/// A wrong variant is always a hard failure, never a silent pass.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.ram_bytes, 32 * 1024 * 1024 * 1024);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.gpus[0].name, "A100");
        assert_eq!(deserialized.gpus[0].memory_bytes, 80 * 1024 * 1024 * 1024);
        assert_eq!(deserialized.python_envs, vec!["py310", "py311"]);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        // Pin the wire tag itself, not just the presence of the variant name.
        assert!(json.contains(r#""type":"Register""#));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Register {
            worker_id,
            capabilities,
        } = deserialized
        {
            assert_eq!(worker_id, "worker_01");
            assert_eq!(capabilities.cpu_cores, 4);
            assert!(capabilities.gpus.is_empty());
            assert_eq!(capabilities.tags, vec!["cpu"]);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        if let CoordinatorToWorker::AssignPlan { plan } = deserialized {
            assert_eq!(plan.plan_id, "plan_001");
            assert!(matches!(plan.input, Some(InputSource::Inline { .. })));
            assert_eq!(plan.metadata["experiment"], "test");
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn unit_variant_serde() {
        // Unit variants of an internally tagged enum carry only the tag.
        let json = serde_json::to_string(&CoordinatorToWorker::Ping).unwrap();
        assert_eq!(json, r#"{"type":"Ping"}"#);
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, CoordinatorToWorker::Ping));
    }

    #[test]
    fn python_job_serde() {
        let msg = CoordinatorToWorker::AssignPythonJob {
            job: PythonPipelineJob {
                job_id: "job_1".into(),
                pipeline_id: "pipe_1".into(),
                investigation_id: "inv_1".into(),
                files: vec![PipelineFile {
                    path: "main.py".into(),
                    content: "print('hi')".into(),
                }],
                requirements: "numpy\n".into(),
                entry_point: "main.py".into(),
                input_data: None,
                params: serde_json::json!({"epochs": 3}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        if let CoordinatorToWorker::AssignPythonJob { job } = deserialized {
            assert_eq!(job.job_id, "job_1");
            assert_eq!(job.files.len(), 1);
            assert_eq!(job.files[0].path, "main.py");
            assert_eq!(job.files[0].content, "print('hi')");
            assert!(job.input_data.is_none());
            assert_eq!(job.params["epochs"], 3);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: Value::tensor(vec![0.95], vec![1]),
            duration_ms: 1234,
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        match deserialized {
            PlanResult::Success { duration_ms, .. } => assert_eq!(duration_ms, 1234),
            other => panic!("expected Success, got {:?}", other),
        }

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        match deserialized {
            PlanResult::Failed { error, duration_ms } => {
                assert_eq!(error, "OOM");
                assert_eq!(duration_ms, 500);
            }
            other => panic!("expected Failed, got {:?}", other),
        }
    }

    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Event {
            worker_id,
            plan_id,
            event,
        } = deserialized
        {
            assert_eq!(worker_id, "w1");
            assert_eq!(plan_id, "p1");
            assert!(matches!(event, Event::RunStarted { .. }));
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Heartbeat { worker_id, load } = deserialized {
            assert_eq!(worker_id, "w1");
            // Floats go through a text round trip, so compare with a tolerance.
            assert!((load.cpu_usage - 0.45).abs() < 1e-6);
            assert!((load.memory_usage - 0.72).abs() < 1e-6);
            assert_eq!(load.gpu_usage.len(), 1);
            assert_eq!(load.active_plans, 2);
            assert_eq!(load.queue_depth, 5);
        } else {
            // Without this branch a wrong variant would pass silently.
            panic!("wrong variant");
        }
    }
}