1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::{DataRef, DataStore};
11use somatize_core::value::Value;
12
/// Identifier of a worker, carried as a plain string in protocol messages.
pub type WorkerId = String;
15
/// Identifier of a plan, carried as a plain string in protocol messages.
pub type PlanId = String;
18
/// Resources a worker advertises; carried by [`WorkerToCoordinator::Register`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
    /// Number of CPU cores.
    pub cpu_cores: usize,
    /// Amount of RAM, in bytes.
    pub ram_bytes: u64,
    /// One entry per GPU; may be empty.
    pub gpus: Vec<GpuInfo>,
    /// Names of Python environments (presumably those usable on the worker —
    /// confirm the expected naming with the producer).
    pub python_envs: Vec<String>,
    /// Free-form labels attached to the worker.
    pub tags: Vec<String>,
}
33
/// Description of a single GPU, as listed in [`Capabilities::gpus`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuInfo {
    /// Device name.
    pub name: String,
    /// GPU memory, in bytes.
    pub memory_bytes: u64,
}
40
/// Snapshot of a worker's load; carried by [`WorkerToCoordinator::Heartbeat`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadMetrics {
    /// CPU usage (presumably a fraction in `0.0..=1.0` — TODO confirm the unit).
    pub cpu_usage: f32,
    /// Memory usage (presumably a fraction in `0.0..=1.0` — TODO confirm the unit).
    pub memory_usage: f32,
    /// Per-GPU usage values (presumably in the same order as
    /// [`Capabilities::gpus`] — TODO confirm).
    pub gpu_usage: Vec<f32>,
    /// Number of active plans.
    pub active_plans: usize,
    /// Depth of the work queue.
    pub queue_depth: usize,
    /// UTC time associated with this snapshot.
    pub timestamp: DateTime<Utc>,
}
51
/// Where the input value of a plan comes from.
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `source` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "source")]
#[non_exhaustive]
pub enum InputSource {
    /// The value is embedded directly in the message.
    Inline { value: Value },
    /// The value must be looked up in a data store through `data_ref`.
    Reference { data_ref: DataRef },
}
62
63impl InputSource {
64 pub fn resolve(
67 &self,
68 data_store: Option<&dyn somatize_core::store::DataStore>,
69 temp_store: &somatize_core::store::LocalDataStore,
70 ) -> Value {
71 match self {
72 InputSource::Inline { value } => value.clone(),
73 InputSource::Reference { data_ref } => {
74 if let Some(store) = data_store
75 && let Ok(val) = store.get(data_ref)
76 {
77 return val;
78 }
79 temp_store.get(data_ref).unwrap_or_else(|e| {
80 tracing::warn!("Failed to resolve DataRef: {e}");
81 Value::Empty
82 })
83 }
84 }
85 }
86}
87
/// A filter shipped alongside a plan (see [`SerializedPlan::filters`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedFilter {
    /// Identifier of the node this filter belongs to.
    pub node_id: String,
    /// Opaque filter bytes; transported as a standard-alphabet base64 string.
    ///
    /// NOTE(review): the field name suggests a Python pickle. Unpickling data
    /// from an untrusted peer can execute arbitrary code — confirm that only
    /// trusted coordinators/workers can produce this payload.
    #[serde(with = "base64_bytes")]
    pub pickled_filter: Vec<u8>,
    /// Optional state accompanying the filter.
    pub state: Option<Value>,
    /// Requirement strings (presumably Python package specifiers — TODO
    /// confirm). Defaults to an empty list when absent from the input.
    #[serde(default)]
    pub requirements: Vec<String>,
    /// Whether the filter is trainable. Defaults to `false` when absent from
    /// the input.
    #[serde(default)]
    pub trainable: bool,
}
108
109mod base64_bytes {
111 use base64::engine::{Engine, general_purpose::STANDARD};
112 use serde::{Deserialize, Deserializer, Serialize, Serializer};
113
114 pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
115 STANDARD.encode(bytes).serialize(s)
116 }
117
118 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
119 let s = String::deserialize(d)?;
120 STANDARD.decode(s).map_err(serde::de::Error::custom)
121 }
122}
123
/// How a plan is to be executed. The default is [`ExecutionMode::Forward`].
///
/// No serde tag attribute is set, so serde's default (externally tagged)
/// enum representation applies.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExecutionMode {
    /// Fitting mode.
    Fit {
        /// Optional value passed along for fitting (presumably the
        /// targets/labels — TODO confirm).
        y: Option<Value>,
        /// Optional batch size; `None` when absent from the input.
        #[serde(default)]
        batch_size: Option<usize>,
    },
    /// Forward mode; used when no mode is specified.
    #[default]
    Forward,
}
141
/// A plan together with everything sent along with it to a worker
/// (see [`CoordinatorToWorker::AssignPlan`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedPlan {
    /// Identifier of this plan.
    pub plan_id: PlanId,
    /// The execution plan itself.
    pub plan: ExecutionPlan,
    /// Optional input, either inline or by reference.
    pub input: Option<InputSource>,
    /// Filters shipped with the plan. Defaults to an empty list when absent
    /// from the input.
    #[serde(default)]
    pub filters: Vec<SerializedFilter>,
    /// Execution mode. Defaults to [`ExecutionMode::Forward`] when absent
    /// from the input.
    #[serde(default)]
    pub mode: ExecutionMode,
    /// Arbitrary JSON metadata.
    pub metadata: serde_json::Value,
}
157
/// Messages sent from a worker to the coordinator.
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `type` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerToCoordinator {
    /// Announces a worker and its capabilities.
    Register {
        worker_id: WorkerId,
        capabilities: Capabilities,
    },

    /// Reports the worker's current load.
    Heartbeat {
        worker_id: WorkerId,
        load: LoadMetrics,
    },

    /// Forwards an [`Event`] associated with a plan.
    Event {
        worker_id: WorkerId,
        plan_id: PlanId,
        event: Event,
    },

    /// Reports the outcome of a plan.
    PlanResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        result: PlanResult,
    },

    /// Reports progress of a job: a phase label, a `step` out of `total`,
    /// and arbitrary JSON metrics.
    JobProgress {
        worker_id: WorkerId,
        job_id: String,
        phase: String,
        step: u32,
        total: u32,
        metrics: serde_json::Value,
    },

    /// Reports the outcome of a job.
    JobResult {
        worker_id: WorkerId,
        job_id: String,
        success: bool,
        metrics: serde_json::Value,
        /// Textual output of the job (presumably captured process output —
        /// TODO confirm).
        output: String,
        /// Duration in milliseconds.
        duration_ms: u64,
    },

    /// Carries states keyed by string (presumably node ids, mirroring
    /// [`CoordinatorToWorker::GetState`] — TODO confirm).
    StateResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        states: std::collections::HashMap<String, Value>,
    },

    /// Carries gradients keyed by string (presumably node ids, mirroring
    /// [`CoordinatorToWorker::GetGradients`] — TODO confirm).
    GradientsResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        gradients: std::collections::HashMap<String, Value>,
    },
}
223
/// Description of a Python pipeline job
/// (see [`CoordinatorToWorker::AssignPythonJob`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonPipelineJob {
    /// Identifier of the job.
    pub job_id: String,
    /// Identifier of the pipeline.
    pub pipeline_id: String,
    /// Identifier of the investigation.
    pub investigation_id: String,
    /// Files that make up the pipeline.
    pub files: Vec<PipelineFile>,
    /// Requirements as a single string (presumably `requirements.txt`-style
    /// content — TODO confirm).
    pub requirements: String,
    /// Entry point of the pipeline (format not visible here — confirm with
    /// the consumer whether this is a file path or a module/function name).
    pub entry_point: String,
    /// Optional JSON input data.
    pub input_data: Option<serde_json::Value>,
    /// JSON parameters.
    pub params: serde_json::Value,
}
241
/// A single file of a [`PythonPipelineJob`], held as text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFile {
    /// Path of the file (presumably relative to the job's working directory —
    /// TODO confirm).
    pub path: String,
    /// File content.
    pub content: String,
}
248
/// Messages sent from the coordinator to a worker.
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `type` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CoordinatorToWorker {
    /// Registration acknowledgement carrying the worker's id.
    Registered { worker_id: WorkerId },

    /// Assigns a plan to the worker.
    AssignPlan { plan: SerializedPlan },

    /// Assigns a Python pipeline job to the worker.
    AssignPythonJob { job: PythonPipelineJob },

    /// Requests cancellation of the given plan.
    CancelPlan { plan_id: PlanId },

    /// Requests a status report.
    StatusRequest,

    /// Ping message.
    Ping,

    /// Requests shutdown, with a human-readable reason.
    Shutdown { reason: String },

    /// Requests the states of the listed nodes of a plan.
    GetState {
        plan_id: PlanId,
        node_ids: Vec<String>,
    },

    /// Supplies states for a plan, keyed by string (presumably node ids —
    /// TODO confirm).
    SetState {
        plan_id: PlanId,
        states: std::collections::HashMap<String, Value>,
    },

    /// Requests the gradients of the listed nodes of a plan.
    GetGradients {
        plan_id: PlanId,
        node_ids: Vec<String>,
    },

    /// Supplies gradients for a plan, keyed by string (presumably node ids —
    /// TODO confirm).
    ApplyGradients {
        plan_id: PlanId,
        gradients: std::collections::HashMap<String, Value>,
    },
}
299
/// How the output of a successful plan is handed over.
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `delivery` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "delivery")]
#[non_exhaustive]
pub enum OutputDelivery {
    /// The output value is embedded directly in the message.
    Inline { value: Value },
    /// The output must be fetched separately using `data_ref`.
    Reference {
        data_ref: somatize_core::store::DataRef,
    },
}
312
313impl OutputDelivery {
314 pub fn resolve(&self, addr: &str, token: &Option<String>) -> Value {
317 match self {
318 OutputDelivery::Inline { value } => value.clone(),
319 OutputDelivery::Reference { data_ref } => {
320 let http_addr = addr
322 .replace("ws://", "http://")
323 .replace("wss://", "https://");
324 let url = format!("{http_addr}/download");
325 let ref_json = serde_json::to_string(data_ref).unwrap_or_default();
326 let token = token.clone();
327
328 std::thread::spawn(move || {
329 let client = reqwest::blocking::Client::new();
330 let mut req = client.get(&url).query(&[("ref", &ref_json)]);
331 if let Some(t) = &token {
332 req = req.query(&[("token", t.as_str())]);
333 }
334 let resp = req.send().ok()?;
335 let bytes = resp.bytes().ok()?;
336 serde_json::from_slice(&bytes).ok()
337 })
338 .join()
339 .ok()
340 .flatten()
341 .unwrap_or(Value::Empty)
342 }
343 }
344 }
345}
346
/// Outcome of executing a plan.
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `status` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum PlanResult {
    /// The plan completed.
    Success {
        /// The produced output, inline or by reference.
        output: OutputDelivery,
        /// Duration in milliseconds.
        duration_ms: u64,
        /// States keyed by string (presumably node ids — TODO confirm).
        /// Defaults to an empty map when absent from the input.
        #[serde(default)]
        states: std::collections::HashMap<String, Value>,
    },
    /// The plan failed.
    Failed {
        /// Error description.
        error: String,
        /// Duration in milliseconds.
        duration_ms: u64,
    },
}
364
/// Messages of the chunked streaming exchange.
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `type` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum StreamMessage {
    /// Opens a stream and carries the plan it applies to.
    StreamBegin {
        stream_id: String,
        plan_id: PlanId,
        /// Number of chunks to expect, when known.
        total_chunks: Option<usize>,
        /// The plan, boxed (presumably to keep the enum itself small —
        /// TODO confirm).
        plan: Box<SerializedPlan>,
    },
    /// One chunk of data, identified by its index within the stream
    /// (presumably an input chunk — TODO confirm direction).
    ChunkData {
        stream_id: String,
        chunk_index: usize,
        value: Value,
    },
    /// Marks the end of the stream's chunks.
    StreamEnd { stream_id: String },
    /// Result for one chunk, identified by its index within the stream.
    ChunkResult {
        stream_id: String,
        chunk_index: usize,
        value: Value,
    },
    /// Final result of the whole stream.
    StreamComplete {
        stream_id: String,
        result: PlanResult,
    },
}
403
/// JSON round-trip tests for the protocol types.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// `Capabilities` survives a JSON round trip with its fields intact.
    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    /// `Register` keeps its variant tag and worker id across a round trip.
    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Register"));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        let WorkerToCoordinator::Register { worker_id, .. } = deserialized else {
            panic!("wrong variant");
        };
        assert_eq!(worker_id, "worker_01");
    }

    /// `AssignPlan` with an inline input round-trips to the same variant.
    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                mode: ExecutionMode::default(),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            CoordinatorToWorker::AssignPlan { .. }
        ));
    }

    /// Both `PlanResult` variants round-trip to the same variant.
    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: OutputDelivery::Inline {
                value: Value::tensor(vec![0.95], vec![1]),
            },
            duration_ms: 1234,
            states: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Failed { .. }));
    }

    /// An `Event` message wrapping `RunStarted` round-trips to the same variant.
    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    /// `Heartbeat` round-trips to the same variant with its load metrics intact.
    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        // Fail loudly on a wrong variant; a bare `if let` would let the test
        // pass without checking anything.
        let WorkerToCoordinator::Heartbeat { load, .. } = deserialized else {
            panic!("wrong variant");
        };
        assert!((load.cpu_usage - 0.45).abs() < f32::EPSILON);
        assert_eq!(load.active_plans, 2);
        assert_eq!(load.queue_depth, 5);
    }
}