1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::DataRef;
11use somatize_core::value::Value;
12
/// Identifier of a worker. A plain alias for `String`, so it adds no type safety.
pub type WorkerId = String;

/// Identifier of an execution plan. A plain alias for `String`, so it adds no type safety.
pub type PlanId = String;
18
/// Description of a worker's resources, carried by
/// `WorkerToCoordinator::Register`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
    /// Number of CPU cores.
    pub cpu_cores: usize,
    /// Amount of RAM, in bytes.
    pub ram_bytes: u64,
    /// One entry per GPU.
    pub gpus: Vec<GpuInfo>,
    /// Names of Python environments, e.g. `"py310"`.
    /// NOTE(review): presumably the environments available on the worker — confirm.
    pub python_envs: Vec<String>,
    /// Free-form labels, e.g. `"gpu"` or `"training"`.
    pub tags: Vec<String>,
}
33
/// Description of a single GPU, listed in `Capabilities::gpus`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuInfo {
    /// Device name, e.g. `"A100"`.
    pub name: String,
    /// GPU memory, in bytes.
    pub memory_bytes: u64,
}
40
/// Load snapshot carried by `WorkerToCoordinator::Heartbeat`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadMetrics {
    /// CPU utilization.
    /// NOTE(review): presumably a 0.0–1.0 fraction — confirm the scale with the producer.
    pub cpu_usage: f32,
    /// Memory utilization (same scale caveat as `cpu_usage`).
    pub memory_usage: f32,
    /// GPU utilization values.
    /// NOTE(review): presumably one entry per GPU, in `Capabilities::gpus` order — confirm.
    pub gpu_usage: Vec<f32>,
    /// Number of active plans.
    pub active_plans: usize,
    /// Depth of the queue.
    pub queue_depth: usize,
    /// UTC timestamp attached to the snapshot.
    pub timestamp: DateTime<Utc>,
}
51
/// Where the input of a plan comes from.
///
/// Serialized in serde's internally tagged form: the `source` field holds the
/// variant name. The enum is `#[non_exhaustive]`, so `match`es in other crates
/// need a wildcard arm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "source")]
#[non_exhaustive]
pub enum InputSource {
    /// The input value is embedded directly in the message.
    Inline { value: Value },
    /// The input is identified by a `DataRef` instead of being embedded.
    Reference { data_ref: DataRef },
}
62
/// A filter in serialized form, listed in `SerializedPlan::filters`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedFilter {
    /// ID of the node this filter is associated with.
    pub node_id: String,
    /// Raw filter bytes, written on the wire as a base64 string (see `base64_bytes`).
    /// NOTE(review): the name suggests a Python pickle payload — confirm with the producer.
    #[serde(with = "base64_bytes")]
    pub pickled_filter: Vec<u8>,
    /// Optional state value accompanying the filter.
    pub state: Option<Value>,
}
77
78mod base64_bytes {
80 use base64::engine::{Engine, general_purpose::STANDARD};
81 use serde::{Deserialize, Deserializer, Serialize, Serializer};
82
83 pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
84 STANDARD.encode(bytes).serialize(s)
85 }
86
87 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
88 let s = String::deserialize(d)?;
89 STANDARD.decode(s).map_err(serde::de::Error::custom)
90 }
91}
92
/// An execution plan with its ID, input, filters, and metadata, as carried by
/// `CoordinatorToWorker::AssignPlan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedPlan {
    /// Identifier of this plan.
    pub plan_id: PlanId,
    /// The compiled execution plan.
    pub plan: ExecutionPlan,
    /// Input for the plan, if any.
    pub input: Option<InputSource>,
    /// Serialized filters; deserializes to an empty list when the field is
    /// absent from the payload.
    #[serde(default)]
    pub filters: Vec<SerializedFilter>,
    /// Arbitrary JSON metadata.
    pub metadata: serde_json::Value,
}
105
/// Messages sent from a worker to the coordinator.
///
/// Serialized in serde's internally tagged form: the `type` field holds the
/// variant name. Every variant carries the sending worker's ID.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerToCoordinator {
    /// Announces a worker together with its capabilities.
    Register {
        worker_id: WorkerId,
        capabilities: Capabilities,
    },

    /// Reports the worker's current load metrics.
    Heartbeat {
        worker_id: WorkerId,
        load: LoadMetrics,
    },

    /// Forwards an `Event` associated with a plan.
    Event {
        worker_id: WorkerId,
        plan_id: PlanId,
        event: Event,
    },

    /// Reports the outcome (`PlanResult`) of a plan.
    PlanResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        result: PlanResult,
    },

    /// Progress report for a job: `step` out of `total` within `phase`,
    /// with arbitrary JSON `metrics`.
    /// NOTE(review): presumably refers to a `PythonPipelineJob` by `job_id` — confirm.
    JobProgress {
        worker_id: WorkerId,
        job_id: String,
        phase: String,
        step: u32,
        total: u32,
        metrics: serde_json::Value,
    },

    /// Final report for a job: success flag, JSON `metrics`, textual `output`,
    /// and duration in milliseconds.
    JobResult {
        worker_id: WorkerId,
        job_id: String,
        success: bool,
        metrics: serde_json::Value,
        output: String,
        duration_ms: u64,
    },
}
156
/// A Python pipeline job, as carried by `CoordinatorToWorker::AssignPythonJob`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonPipelineJob {
    /// Identifier of this job.
    pub job_id: String,
    /// Identifier of the pipeline.
    pub pipeline_id: String,
    /// Identifier of the investigation.
    pub investigation_id: String,
    /// Files that make up the job (path plus textual content).
    pub files: Vec<PipelineFile>,
    /// Requirements specification.
    /// NOTE(review): presumably the text of a pip requirements file — confirm.
    pub requirements: String,
    /// Entry point of the job.
    /// NOTE(review): format (file path vs. module/function) is not visible here — confirm.
    pub entry_point: String,
    /// Optional JSON input data.
    pub input_data: Option<serde_json::Value>,
    /// JSON parameters.
    pub params: serde_json::Value,
}
174
/// A single file of a `PythonPipelineJob`, carried inline as text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFile {
    /// Path of the file.
    /// NOTE(review): presumably relative to the job's working directory — confirm.
    pub path: String,
    /// Content of the file, as a string.
    pub content: String,
}
181
/// Messages sent from the coordinator to a worker.
///
/// Serialized in serde's internally tagged form: the `type` field holds the
/// variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CoordinatorToWorker {
    /// Carries a worker ID.
    /// NOTE(review): presumably the acknowledgement of a `Register` message — confirm.
    Registered { worker_id: WorkerId },

    /// Assigns a serialized plan to the worker.
    AssignPlan { plan: SerializedPlan },

    /// Assigns a Python pipeline job to the worker.
    AssignPythonJob { job: PythonPipelineJob },

    /// Requests cancellation of the plan with the given ID.
    CancelPlan { plan_id: PlanId },

    /// Requests the worker's status; carries no payload.
    StatusRequest,

    /// Ping; carries no payload.
    Ping,
}
204
/// Outcome of a plan.
///
/// Serialized in serde's internally tagged form: the `status` field holds the
/// variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum PlanResult {
    /// The plan produced `output`; `duration_ms` is the duration in milliseconds.
    Success { output: Value, duration_ms: u64 },
    /// The plan failed with the message `error`; `duration_ms` is the duration in milliseconds.
    Failed { error: String, duration_ms: u64 },
}
212
/// JSON round-trip tests for the protocol types.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// `Capabilities` survives a JSON round trip with its fields intact.
    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    /// `Register` round-trips and its JSON contains the variant tag.
    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Register"));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Register { worker_id, .. } = deserialized {
            assert_eq!(worker_id, "worker_01");
        } else {
            panic!("wrong variant");
        }
    }

    /// `AssignPlan` with an inline input round-trips to the same variant.
    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            CoordinatorToWorker::AssignPlan { .. }
        ));
    }

    /// Both `PlanResult` variants round-trip to the same variant.
    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: Value::tensor(vec![0.95], vec![1]),
            duration_ms: 1234,
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Failed { .. }));
    }

    /// An `Event` message wrapping `Event::RunStarted` round-trips to the same variant.
    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    /// `Heartbeat` round-trips and keeps its load metrics.
    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Heartbeat { load, .. } = deserialized {
            assert!(load.cpu_usage > 0.0);
            assert_eq!(load.active_plans, 2);
        } else {
            // Without this arm a round trip yielding another variant would
            // make the test pass without checking anything.
            panic!("wrong variant");
        }
    }
}