1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::DataRef;
11use somatize_core::value::Value;
12
/// Identifier of a worker, carried as a plain `String`.
///
/// Used for the `worker_id` field of every [`WorkerToCoordinator`] message
/// and of [`CoordinatorToWorker::Registered`].
pub type WorkerId = String;
15
/// Identifier of an execution plan, carried as a plain `String`.
///
/// Used for `plan_id` fields throughout the protocol messages in this file.
pub type PlanId = String;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Capabilities {
22 pub cpu_cores: usize,
24 pub ram_bytes: u64,
26 pub gpus: Vec<GpuInfo>,
28 pub python_envs: Vec<String>,
30 pub tags: Vec<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct GpuInfo {
37 pub name: String,
38 pub memory_bytes: u64,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct LoadMetrics {
44 pub cpu_usage: f32,
45 pub memory_usage: f32,
46 pub gpu_usage: Vec<f32>,
47 pub active_plans: usize,
48 pub queue_depth: usize,
49 pub timestamp: DateTime<Utc>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54#[serde(tag = "source")]
55#[non_exhaustive]
56pub enum InputSource {
57 Inline { value: Value },
59 Reference { data_ref: DataRef },
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct SerializedFilter {
69 pub node_id: String,
71 #[serde(with = "base64_bytes")]
73 pub pickled_filter: Vec<u8>,
74 pub state: Option<Value>,
76 #[serde(default)]
78 pub requirements: Vec<String>,
79}
80
81mod base64_bytes {
83 use base64::engine::{Engine, general_purpose::STANDARD};
84 use serde::{Deserialize, Deserializer, Serialize, Serializer};
85
86 pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
87 STANDARD.encode(bytes).serialize(s)
88 }
89
90 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
91 let s = String::deserialize(d)?;
92 STANDARD.decode(s).map_err(serde::de::Error::custom)
93 }
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize, Default)]
98#[non_exhaustive]
99pub enum ExecutionMode {
100 Fit {
102 y: Option<Value>,
104 },
105 #[default]
107 Forward,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct SerializedPlan {
113 pub plan_id: PlanId,
114 pub plan: ExecutionPlan,
115 pub input: Option<InputSource>,
117 #[serde(default)]
119 pub filters: Vec<SerializedFilter>,
120 #[serde(default)]
122 pub mode: ExecutionMode,
123 pub metadata: serde_json::Value,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128#[serde(tag = "type")]
129pub enum WorkerToCoordinator {
130 Register {
132 worker_id: WorkerId,
133 capabilities: Capabilities,
134 },
135
136 Heartbeat {
138 worker_id: WorkerId,
139 load: LoadMetrics,
140 },
141
142 Event {
144 worker_id: WorkerId,
145 plan_id: PlanId,
146 event: Event,
147 },
148
149 PlanResult {
151 worker_id: WorkerId,
152 plan_id: PlanId,
153 result: PlanResult,
154 },
155
156 JobProgress {
158 worker_id: WorkerId,
159 job_id: String,
160 phase: String,
161 step: u32,
162 total: u32,
163 metrics: serde_json::Value,
164 },
165
166 JobResult {
168 worker_id: WorkerId,
169 job_id: String,
170 success: bool,
171 metrics: serde_json::Value,
172 output: String,
173 duration_ms: u64,
174 },
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct PythonPipelineJob {
180 pub job_id: String,
181 pub pipeline_id: String,
182 pub investigation_id: String,
183 pub files: Vec<PipelineFile>,
185 pub requirements: String,
187 pub entry_point: String,
189 pub input_data: Option<serde_json::Value>,
191 pub params: serde_json::Value,
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct PipelineFile {
198 pub path: String,
199 pub content: String,
200}
201
202#[derive(Debug, Clone, Serialize, Deserialize)]
204#[serde(tag = "type")]
205pub enum CoordinatorToWorker {
206 Registered { worker_id: WorkerId },
208
209 AssignPlan { plan: SerializedPlan },
211
212 AssignPythonJob { job: PythonPipelineJob },
214
215 CancelPlan { plan_id: PlanId },
217
218 StatusRequest,
220
221 Ping,
223
224 Shutdown { reason: String },
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
230#[serde(tag = "status")]
231pub enum PlanResult {
232 Success {
233 output: Value,
234 duration_ms: u64,
235 #[serde(default)]
238 states: std::collections::HashMap<String, Value>,
239 },
240 Failed {
241 error: String,
242 duration_ms: u64,
243 },
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
252#[serde(tag = "type")]
253#[non_exhaustive]
254pub enum StreamMessage {
255 StreamBegin {
257 stream_id: String,
258 plan_id: PlanId,
259 total_chunks: Option<usize>,
261 plan: Box<SerializedPlan>,
263 },
264 ChunkData {
266 stream_id: String,
267 chunk_index: usize,
268 value: Value,
269 },
270 StreamEnd { stream_id: String },
272 ChunkResult {
274 stream_id: String,
275 chunk_index: usize,
276 value: Value,
277 },
278 StreamComplete {
280 stream_id: String,
281 result: PlanResult,
282 },
283}
284
/// JSON round-trip tests for the protocol types defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// `Capabilities` keeps its field values across a JSON round trip.
    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    /// `Register` round-trips and its JSON contains the variant name.
    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Register"));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Register { worker_id, .. } = deserialized {
            assert_eq!(worker_id, "worker_01");
        } else {
            panic!("wrong variant");
        }
    }

    /// `AssignPlan` with an inline input round-trips to the same variant.
    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                mode: ExecutionMode::default(),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            CoordinatorToWorker::AssignPlan { .. }
        ));
    }

    /// Both `PlanResult` variants round-trip to the same variant.
    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: Value::tensor(vec![0.95], vec![1]),
            duration_ms: 1234,
            states: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Failed { .. }));
    }

    /// An `Event` message wrapping `Event::RunStarted` round-trips.
    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    /// `Heartbeat` round-trips and keeps its load metrics.
    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Heartbeat { load, .. } = deserialized {
            assert!(load.cpu_usage > 0.0);
            assert_eq!(load.active_plans, 2);
            assert_eq!(load.queue_depth, 5);
        } else {
            // Without this branch a mis-tagged message would skip every
            // assertion above and the test would pass vacuously.
            panic!("wrong variant");
        }
    }
}