1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::DataRef;
11use somatize_core::value::Value;
12
/// String identifier naming a worker (e.g. `"worker_01"`).
///
/// Every `WorkerToCoordinator` variant carries one in its `worker_id` field.
pub type WorkerId = std::string::String;
15
/// String identifier naming an execution plan (e.g. `"plan_001"`).
pub type PlanId = std::string::String;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Capabilities {
22 pub cpu_cores: usize,
24 pub ram_bytes: u64,
26 pub gpus: Vec<GpuInfo>,
28 pub python_envs: Vec<String>,
30 pub tags: Vec<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct GpuInfo {
37 pub name: String,
38 pub memory_bytes: u64,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct LoadMetrics {
44 pub cpu_usage: f32,
45 pub memory_usage: f32,
46 pub gpu_usage: Vec<f32>,
47 pub active_plans: usize,
48 pub queue_depth: usize,
49 pub timestamp: DateTime<Utc>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54#[serde(tag = "source")]
55#[non_exhaustive]
56pub enum InputSource {
57 Inline { value: Value },
59 Reference { data_ref: DataRef },
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct SerializedPlan {
66 pub plan_id: PlanId,
67 pub plan: ExecutionPlan,
68 pub input: Option<InputSource>,
70 pub metadata: serde_json::Value,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75#[serde(tag = "type")]
76pub enum WorkerToCoordinator {
77 Register {
79 worker_id: WorkerId,
80 capabilities: Capabilities,
81 },
82
83 Heartbeat {
85 worker_id: WorkerId,
86 load: LoadMetrics,
87 },
88
89 Event {
91 worker_id: WorkerId,
92 plan_id: PlanId,
93 event: Event,
94 },
95
96 PlanResult {
98 worker_id: WorkerId,
99 plan_id: PlanId,
100 result: PlanResult,
101 },
102
103 JobProgress {
105 worker_id: WorkerId,
106 job_id: String,
107 phase: String,
108 step: u32,
109 total: u32,
110 metrics: serde_json::Value,
111 },
112
113 JobResult {
115 worker_id: WorkerId,
116 job_id: String,
117 success: bool,
118 metrics: serde_json::Value,
119 output: String,
120 duration_ms: u64,
121 },
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct PythonPipelineJob {
127 pub job_id: String,
128 pub pipeline_id: String,
129 pub investigation_id: String,
130 pub files: Vec<PipelineFile>,
132 pub requirements: String,
134 pub entry_point: String,
136 pub input_data: Option<serde_json::Value>,
138 pub params: serde_json::Value,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct PipelineFile {
145 pub path: String,
146 pub content: String,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
151#[serde(tag = "type")]
152pub enum CoordinatorToWorker {
153 Registered { worker_id: WorkerId },
155
156 AssignPlan { plan: SerializedPlan },
158
159 AssignPythonJob { job: PythonPipelineJob },
161
162 CancelPlan { plan_id: PlanId },
164
165 StatusRequest,
167
168 Ping,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
174#[serde(tag = "status")]
175pub enum PlanResult {
176 Success { output: Value, duration_ms: u64 },
177 Failed { error: String, duration_ms: u64 },
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// Serializes `value` to compact JSON and parses it back.
    ///
    /// Returns both the JSON text, so tests can check tag fields, and the
    /// decoded value. Panics, failing the test, if either direction errors.
    fn round_trip<T>(value: &T) -> (String, T)
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let json = serde_json::to_string(value).expect("value should serialize to JSON");
        let decoded = serde_json::from_str(&json).expect("serialized JSON should deserialize");
        (json, decoded)
    }

    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let (_, decoded) = round_trip(&caps);
        assert_eq!(decoded.cpu_cores, 8);
        assert_eq!(decoded.ram_bytes, 32_u64 * 1024 * 1024 * 1024);
        assert_eq!(decoded.gpus.len(), 1);
        assert_eq!(decoded.gpus[0].name, "A100");
        assert_eq!(decoded.gpus[0].memory_bytes, 80_u64 * 1024 * 1024 * 1024);
        assert_eq!(decoded.python_envs, vec!["py310", "py311"]);
        assert_eq!(decoded.tags, vec!["gpu", "training"]);
    }

    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let (json, decoded) = round_trip(&msg);
        // Internally tagged: the variant name lives in a `type` field.
        assert!(json.contains(r#""type":"Register""#));
        match decoded {
            WorkerToCoordinator::Register {
                worker_id,
                capabilities,
            } => {
                assert_eq!(worker_id, "worker_01");
                assert_eq!(capabilities.cpu_cores, 4);
                assert_eq!(capabilities.ram_bytes, 16_000_000_000);
                assert!(capabilities.gpus.is_empty());
                assert_eq!(capabilities.tags, vec!["cpu"]);
            }
            other => panic!("expected Register, got {:?}", other),
        }
    }

    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let (json, decoded) = round_trip(&msg);
        assert!(json.contains(r#""type":"AssignPlan""#));
        // `InputSource` carries its own tag under `source`.
        assert!(json.contains(r#""source":"Inline""#));
        match decoded {
            CoordinatorToWorker::AssignPlan { plan } => {
                assert_eq!(plan.plan_id, "plan_001");
                assert!(matches!(plan.input, Some(InputSource::Inline { .. })));
                assert_eq!(plan.metadata, serde_json::json!({"experiment": "test"}));
            }
            other => panic!("expected AssignPlan, got {:?}", other),
        }
    }

    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: Value::tensor(vec![0.95], vec![1]),
            duration_ms: 1234,
        };
        let (json, decoded) = round_trip(&success);
        assert!(json.contains(r#""status":"Success""#));
        match decoded {
            PlanResult::Success { duration_ms, .. } => assert_eq!(duration_ms, 1234),
            other => panic!("expected Success, got {:?}", other),
        }

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let (json, decoded) = round_trip(&failed);
        assert!(json.contains(r#""status":"Failed""#));
        match decoded {
            PlanResult::Failed { error, duration_ms } => {
                assert_eq!(error, "OOM");
                assert_eq!(duration_ms, 500);
            }
            other => panic!("expected Failed, got {:?}", other),
        }
    }

    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let (json, decoded) = round_trip(&msg);
        assert!(json.contains(r#""type":"Event""#));
        match decoded {
            WorkerToCoordinator::Event {
                worker_id,
                plan_id,
                event,
            } => {
                assert_eq!(worker_id, "w1");
                assert_eq!(plan_id, "p1");
                assert!(matches!(event, Event::RunStarted { .. }));
            }
            other => panic!("expected Event, got {:?}", other),
        }
    }

    #[test]
    fn heartbeat_serde() {
        let now = Utc::now();
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: now,
            },
        };
        let (json, decoded) = round_trip(&msg);
        assert!(json.contains(r#""type":"Heartbeat""#));
        // Previously an `if let` without `else`, which silently passed when the
        // wrong variant came back; a wrong variant now fails the test.
        match decoded {
            WorkerToCoordinator::Heartbeat { worker_id, load } => {
                assert_eq!(worker_id, "w1");
                // Floats take a JSON detour; compare with a tolerance, not exactly.
                assert!((load.cpu_usage - 0.45).abs() < 1e-6);
                assert!((load.memory_usage - 0.72).abs() < 1e-6);
                assert_eq!(load.gpu_usage.len(), 1);
                assert!((load.gpu_usage[0] - 0.88).abs() < 1e-6);
                assert_eq!(load.active_plans, 2);
                assert_eq!(load.queue_depth, 5);
                assert_eq!(load.timestamp, now);
            }
            other => panic!("expected Heartbeat, got {:?}", other),
        }
    }
}