1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::DataRef;
11use somatize_core::value::Value;
12
/// Alias for the `String` used as the type of every `worker_id` field in this module.
pub type WorkerId = String;
15
/// Alias for the `String` used as the type of every `plan_id` field in this module.
pub type PlanId = String;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Capabilities {
22 pub cpu_cores: usize,
24 pub ram_bytes: u64,
26 pub gpus: Vec<GpuInfo>,
28 pub python_envs: Vec<String>,
30 pub tags: Vec<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct GpuInfo {
37 pub name: String,
38 pub memory_bytes: u64,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct LoadMetrics {
44 pub cpu_usage: f32,
45 pub memory_usage: f32,
46 pub gpu_usage: Vec<f32>,
47 pub active_plans: usize,
48 pub queue_depth: usize,
49 pub timestamp: DateTime<Utc>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54#[serde(tag = "source")]
55#[non_exhaustive]
56pub enum InputSource {
57 Inline { value: Value },
59 Reference { data_ref: DataRef },
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct SerializedFilter {
69 pub node_id: String,
71 #[serde(with = "base64_bytes")]
73 pub pickled_filter: Vec<u8>,
74 pub state: Option<Value>,
76 #[serde(default)]
78 pub requirements: Vec<String>,
79 #[serde(default)]
81 pub trainable: bool,
82}
83
84mod base64_bytes {
86 use base64::engine::{Engine, general_purpose::STANDARD};
87 use serde::{Deserialize, Deserializer, Serialize, Serializer};
88
89 pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
90 STANDARD.encode(bytes).serialize(s)
91 }
92
93 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
94 let s = String::deserialize(d)?;
95 STANDARD.decode(s).map_err(serde::de::Error::custom)
96 }
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize, Default)]
101#[non_exhaustive]
102pub enum ExecutionMode {
103 Fit {
105 y: Option<Value>,
107 },
108 #[default]
110 Forward,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct SerializedPlan {
116 pub plan_id: PlanId,
117 pub plan: ExecutionPlan,
118 pub input: Option<InputSource>,
120 #[serde(default)]
122 pub filters: Vec<SerializedFilter>,
123 #[serde(default)]
125 pub mode: ExecutionMode,
126 pub metadata: serde_json::Value,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131#[serde(tag = "type")]
132pub enum WorkerToCoordinator {
133 Register {
135 worker_id: WorkerId,
136 capabilities: Capabilities,
137 },
138
139 Heartbeat {
141 worker_id: WorkerId,
142 load: LoadMetrics,
143 },
144
145 Event {
147 worker_id: WorkerId,
148 plan_id: PlanId,
149 event: Event,
150 },
151
152 PlanResult {
154 worker_id: WorkerId,
155 plan_id: PlanId,
156 result: PlanResult,
157 },
158
159 JobProgress {
161 worker_id: WorkerId,
162 job_id: String,
163 phase: String,
164 step: u32,
165 total: u32,
166 metrics: serde_json::Value,
167 },
168
169 JobResult {
171 worker_id: WorkerId,
172 job_id: String,
173 success: bool,
174 metrics: serde_json::Value,
175 output: String,
176 duration_ms: u64,
177 },
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct PythonPipelineJob {
183 pub job_id: String,
184 pub pipeline_id: String,
185 pub investigation_id: String,
186 pub files: Vec<PipelineFile>,
188 pub requirements: String,
190 pub entry_point: String,
192 pub input_data: Option<serde_json::Value>,
194 pub params: serde_json::Value,
196}
197
198#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct PipelineFile {
201 pub path: String,
202 pub content: String,
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
207#[serde(tag = "type")]
208pub enum CoordinatorToWorker {
209 Registered { worker_id: WorkerId },
211
212 AssignPlan { plan: SerializedPlan },
214
215 AssignPythonJob { job: PythonPipelineJob },
217
218 CancelPlan { plan_id: PlanId },
220
221 StatusRequest,
223
224 Ping,
226
227 Shutdown { reason: String },
229}
230
231#[derive(Debug, Clone, Serialize, Deserialize)]
233#[serde(tag = "status")]
234pub enum PlanResult {
235 Success {
236 output: Value,
237 duration_ms: u64,
238 #[serde(default)]
241 states: std::collections::HashMap<String, Value>,
242 },
243 Failed {
244 error: String,
245 duration_ms: u64,
246 },
247}
248
249#[derive(Debug, Clone, Serialize, Deserialize)]
255#[serde(tag = "type")]
256#[non_exhaustive]
257pub enum StreamMessage {
258 StreamBegin {
260 stream_id: String,
261 plan_id: PlanId,
262 total_chunks: Option<usize>,
264 plan: Box<SerializedPlan>,
266 },
267 ChunkData {
269 stream_id: String,
270 chunk_index: usize,
271 value: Value,
272 },
273 StreamEnd { stream_id: String },
275 ChunkResult {
277 stream_id: String,
278 chunk_index: usize,
279 value: Value,
280 },
281 StreamComplete {
283 stream_id: String,
284 result: PlanResult,
285 },
286}
287
/// JSON round-trip tests for the protocol types defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// `Capabilities` survives a JSON round trip with its fields intact.
    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    /// `Register` serializes with its variant name and deserializes to the same variant.
    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Register"));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Register { worker_id, .. } = deserialized {
            assert_eq!(worker_id, "worker_01");
        } else {
            panic!("wrong variant");
        }
    }

    /// `AssignPlan` with an inline input survives a JSON round trip.
    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                mode: ExecutionMode::default(),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            CoordinatorToWorker::AssignPlan { .. }
        ));
    }

    /// Both `PlanResult` variants keep their variant through a JSON round trip.
    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: Value::tensor(vec![0.95], vec![1]),
            duration_ms: 1234,
            states: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Failed { .. }));
    }

    /// An `Event` message wrapping `Event::RunStarted` survives a JSON round trip.
    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    /// `Heartbeat` survives a JSON round trip with its load metrics.
    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        if let WorkerToCoordinator::Heartbeat { load, .. } = deserialized {
            assert!(load.cpu_usage > 0.0);
            assert_eq!(load.active_plans, 2);
        } else {
            // Without this branch a wrong variant would let the test pass vacuously.
            panic!("wrong variant");
        }
    }

    /// `pickled_filter` goes over the wire as a padded standard-alphabet
    /// base64 string and decodes back to the same bytes.
    #[test]
    fn serialized_filter_base64_roundtrip() {
        let filter = SerializedFilter {
            node_id: "scale".into(),
            pickled_filter: vec![0x00, 0x01, 0x02, 0xff],
            state: None,
            requirements: vec![],
            trainable: false,
        };
        let json = serde_json::to_string(&filter).unwrap();
        assert!(json.contains("\"AAEC/w==\""));
        let deserialized: SerializedFilter = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.node_id, "scale");
        assert_eq!(deserialized.pickled_filter, vec![0x00, 0x01, 0x02, 0xff]);
    }
}