1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::{DataRef, DataStore};
11use somatize_core::value::Value;
12
/// Identifier of a worker (a plain `String` alias, e.g. `"worker_01"`).
pub type WorkerId = String;

/// Identifier of an execution plan (a plain `String` alias, e.g. `"plan_001"`).
pub type PlanId = String;
18
/// Hardware and environment capabilities a worker reports in
/// [`WorkerToCoordinator::Register`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
    /// Number of CPU cores.
    pub cpu_cores: usize,
    /// Total RAM, in bytes.
    pub ram_bytes: u64,
    /// GPUs available on the worker; may be empty.
    pub gpus: Vec<GpuInfo>,
    /// Names of available Python environments (e.g. `"py310"`).
    pub python_envs: Vec<String>,
    /// Free-form labels such as `"gpu"` or `"cpu"`.
    /// NOTE(review): presumably consumed by the scheduler for placement — confirm.
    pub tags: Vec<String>,
}
33
/// A single GPU listed in [`Capabilities::gpus`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuInfo {
    /// Device name (e.g. `"A100"`).
    pub name: String,
    /// Device memory, in bytes.
    pub memory_bytes: u64,
}
40
/// Load snapshot carried by [`WorkerToCoordinator::Heartbeat`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadMetrics {
    /// CPU utilisation; presumably a fraction in `0.0..=1.0` — confirm with the producer.
    pub cpu_usage: f32,
    /// Memory utilisation; presumably a fraction in `0.0..=1.0` — confirm with the producer.
    pub memory_usage: f32,
    /// Per-GPU utilisation; presumably one entry per [`GpuInfo`] in the same order — confirm.
    pub gpu_usage: Vec<f32>,
    /// Number of plans currently running on the worker.
    pub active_plans: usize,
    /// Number of items waiting in the worker's queue.
    pub queue_depth: usize,
    /// When this snapshot was taken, in UTC.
    pub timestamp: DateTime<Utc>,
}
51
/// Where a plan's input comes from.
///
/// Serialized as an internally tagged object: the `"source"` field holds the variant name
/// (`"Inline"` or `"Reference"`). Use [`InputSource::resolve`] to obtain the actual [`Value`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "source")]
#[non_exhaustive]
pub enum InputSource {
    /// The input value is embedded in the message.
    Inline { value: Value },
    /// The input lives in a data store and is looked up by reference.
    Reference { data_ref: DataRef },
}
62
63impl InputSource {
64 pub fn resolve(
67 &self,
68 data_store: Option<&dyn somatize_core::store::DataStore>,
69 temp_store: &somatize_core::store::LocalDataStore,
70 ) -> Value {
71 match self {
72 InputSource::Inline { value } => value.clone(),
73 InputSource::Reference { data_ref } => {
74 if let Some(store) = data_store
75 && let Ok(val) = store.get(data_ref)
76 {
77 return val;
78 }
79 temp_store.get(data_ref).unwrap_or_else(|e| {
80 tracing::warn!("Failed to resolve DataRef: {e}");
81 Value::Empty
82 })
83 }
84 }
85 }
86}
87
/// A filter shipped to a worker in serialized form as part of a [`SerializedPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedFilter {
    /// Id of the plan node this filter belongs to.
    pub node_id: String,
    /// Opaque serialized filter bytes. On the wire this is a standard (padded) base64 string.
    ///
    /// NOTE(review): the name suggests Python `pickle` data. If so, it must only be loaded
    /// from a trusted coordinator, because unpickling can execute arbitrary code — confirm.
    #[serde(with = "base64_bytes")]
    pub pickled_filter: Vec<u8>,
    /// Optional saved state for the filter.
    pub state: Option<Value>,
    /// Extra package requirements; empty when the field is absent on the wire.
    #[serde(default)]
    pub requirements: Vec<String>,
    /// Whether the filter is trainable; `false` when the field is absent on the wire.
    #[serde(default)]
    pub trainable: bool,
}
108
109mod base64_bytes {
111 use base64::engine::{Engine, general_purpose::STANDARD};
112 use serde::{Deserialize, Deserializer, Serialize, Serializer};
113
114 pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
115 STANDARD.encode(bytes).serialize(s)
116 }
117
118 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
119 let s = String::deserialize(d)?;
120 STANDARD.decode(s).map_err(serde::de::Error::custom)
121 }
122}
123
/// How a worker should run an assigned plan.
///
/// Uses serde's default (externally tagged) enum representation. The default is
/// [`ExecutionMode::Forward`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExecutionMode {
    /// Fitting run with optional targets `y`.
    /// NOTE(review): `y` is presumably supervised labels — confirm.
    Fit {
        y: Option<Value>,
    },
    /// Plain forward pass; this is the default.
    #[default]
    Forward,
}
137
/// A complete unit of work sent to a worker in [`CoordinatorToWorker::AssignPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedPlan {
    /// Identifier of this plan; worker-to-coordinator messages refer back to it via `plan_id`.
    pub plan_id: PlanId,
    /// The compiled execution plan to run.
    pub plan: ExecutionPlan,
    /// Optional input for the plan.
    pub input: Option<InputSource>,
    /// Serialized filters for the plan; empty when the field is absent on the wire.
    #[serde(default)]
    pub filters: Vec<SerializedFilter>,
    /// Execution mode; [`ExecutionMode::Forward`] when the field is absent on the wire.
    #[serde(default)]
    pub mode: ExecutionMode,
    /// Arbitrary JSON metadata (e.g. `{"experiment": "test"}`).
    pub metadata: serde_json::Value,
}
153
/// Messages sent from a worker to the coordinator.
///
/// Serialized as an internally tagged object: the `"type"` field holds the variant name
/// (e.g. `{"type":"Register", ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerToCoordinator {
    /// Announces a worker together with its capabilities.
    Register {
        worker_id: WorkerId,
        capabilities: Capabilities,
    },

    /// Liveness report carrying the worker's current load.
    Heartbeat {
        worker_id: WorkerId,
        load: LoadMetrics,
    },

    /// An execution event emitted while running plan `plan_id`.
    Event {
        worker_id: WorkerId,
        plan_id: PlanId,
        event: Event,
    },

    /// Final outcome of plan `plan_id`.
    PlanResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        result: PlanResult,
    },

    /// Progress of job `job_id`: `step` of `total` within `phase`, plus free-form JSON
    /// `metrics`. Presumably refers to a [`PythonPipelineJob`] — confirm.
    JobProgress {
        worker_id: WorkerId,
        job_id: String,
        phase: String,
        step: u32,
        total: u32,
        metrics: serde_json::Value,
    },

    /// Final outcome of job `job_id`, with captured `output`, JSON `metrics`, and elapsed
    /// time in milliseconds.
    JobResult {
        worker_id: WorkerId,
        job_id: String,
        success: bool,
        metrics: serde_json::Value,
        output: String,
        duration_ms: u64,
    },

    /// Node states for plan `plan_id`. Presumably the reply to
    /// [`CoordinatorToWorker::GetState`], keyed by node id — confirm.
    StateResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        states: std::collections::HashMap<String, Value>,
    },

    /// Gradients for plan `plan_id`. Presumably the reply to
    /// [`CoordinatorToWorker::GetGradients`], keyed by node id — confirm.
    GradientsResult {
        worker_id: WorkerId,
        plan_id: PlanId,
        gradients: std::collections::HashMap<String, Value>,
    },
}
219
/// A Python pipeline job sent in [`CoordinatorToWorker::AssignPythonJob`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonPipelineJob {
    /// Identifier of this job.
    pub job_id: String,
    /// Identifier of the pipeline being run.
    pub pipeline_id: String,
    /// Identifier of the investigation this job belongs to.
    pub investigation_id: String,
    /// Source files that make up the pipeline.
    pub files: Vec<PipelineFile>,
    /// Dependency specification as one string.
    /// NOTE(review): presumably `requirements.txt` content — confirm.
    pub requirements: String,
    /// Entry point to invoke. The format (file path vs. module/function) is not visible here —
    /// confirm.
    pub entry_point: String,
    /// Optional JSON input data for the run.
    pub input_data: Option<serde_json::Value>,
    /// Free-form JSON parameters for the run.
    pub params: serde_json::Value,
}
237
/// One file of a [`PythonPipelineJob`]: a path plus its text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFile {
    /// File path. NOTE(review): presumably relative to the job's working directory — confirm.
    pub path: String,
    /// Full file contents as text.
    pub content: String,
}
244
/// Messages sent from the coordinator to a worker.
///
/// Serialized as an internally tagged object: the `"type"` field holds the variant name
/// (unit variants become e.g. `{"type":"Ping"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CoordinatorToWorker {
    /// Acknowledges registration, carrying the worker's id.
    Registered { worker_id: WorkerId },

    /// Run the given plan.
    AssignPlan { plan: SerializedPlan },

    /// Run the given Python pipeline job.
    AssignPythonJob { job: PythonPipelineJob },

    /// Cancel plan `plan_id`.
    CancelPlan { plan_id: PlanId },

    /// Ask the worker to report its status.
    StatusRequest,

    /// Liveness probe.
    Ping,

    /// Ask the worker to shut down, with a human-readable reason.
    Shutdown { reason: String },

    /// Request the states of `node_ids` within plan `plan_id`.
    GetState {
        plan_id: PlanId,
        node_ids: Vec<String>,
    },

    /// Replace node states within plan `plan_id`; keys are presumably node ids — confirm.
    SetState {
        plan_id: PlanId,
        states: std::collections::HashMap<String, Value>,
    },

    /// Request the gradients of `node_ids` within plan `plan_id`.
    GetGradients {
        plan_id: PlanId,
        node_ids: Vec<String>,
    },

    /// Apply gradients within plan `plan_id`; keys are presumably node ids — confirm.
    ApplyGradients {
        plan_id: PlanId,
        gradients: std::collections::HashMap<String, Value>,
    },
}
295
/// How a successful plan's output is handed back.
///
/// Serialized as an internally tagged object: the `"delivery"` field holds the variant name.
/// Use [`OutputDelivery::resolve`] to obtain the actual [`Value`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "delivery")]
#[non_exhaustive]
pub enum OutputDelivery {
    /// The output value is embedded in the message.
    Inline { value: Value },
    /// The output is stored remotely; `resolve` downloads it over HTTP.
    Reference {
        data_ref: somatize_core::store::DataRef,
    },
}
308
309impl OutputDelivery {
310 pub fn resolve(&self, addr: &str, token: &Option<String>) -> Value {
313 match self {
314 OutputDelivery::Inline { value } => value.clone(),
315 OutputDelivery::Reference { data_ref } => {
316 let http_addr = addr
318 .replace("ws://", "http://")
319 .replace("wss://", "https://");
320 let url = format!("{http_addr}/download");
321 let ref_json = serde_json::to_string(data_ref).unwrap_or_default();
322 let token = token.clone();
323
324 std::thread::spawn(move || {
325 let client = reqwest::blocking::Client::new();
326 let mut req = client.get(&url).query(&[("ref", &ref_json)]);
327 if let Some(t) = &token {
328 req = req.query(&[("token", t.as_str())]);
329 }
330 let resp = req.send().ok()?;
331 let bytes = resp.bytes().ok()?;
332 serde_json::from_slice(&bytes).ok()
333 })
334 .join()
335 .ok()
336 .flatten()
337 .unwrap_or(Value::Empty)
338 }
339 }
340 }
341}
342
/// Outcome of executing a plan.
///
/// Serialized as an internally tagged object: the `"status"` field is `"Success"` or
/// `"Failed"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum PlanResult {
    /// The plan completed.
    Success {
        /// Where to find the plan's output.
        output: OutputDelivery,
        /// Elapsed time in milliseconds.
        duration_ms: u64,
        /// Node states reported with the result; empty when the field is absent on the wire.
        #[serde(default)]
        states: std::collections::HashMap<String, Value>,
    },
    /// The plan failed.
    Failed {
        /// Human-readable error description.
        error: String,
        /// Elapsed time in milliseconds before the failure.
        duration_ms: u64,
    },
}
360
/// Messages for running a plan over a stream of chunks.
///
/// Serialized as an internally tagged object: the `"type"` field holds the variant name.
///
/// NOTE(review): the presumed flow is
/// 1. `StreamBegin`;
/// 2. one `ChunkData` per input chunk;
/// 3. `StreamEnd`.
///
/// Results presumably come back as one `ChunkResult` per chunk and a final
/// `StreamComplete`. Confirm against the sender and receiver.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum StreamMessage {
    /// Opens stream `stream_id` for plan `plan_id`.
    ///
    /// `total_chunks` may be `None`, presumably when the count is not known up front.
    /// The plan is boxed, presumably to keep the enum's size small.
    StreamBegin {
        stream_id: String,
        plan_id: PlanId,
        total_chunks: Option<usize>,
        plan: Box<SerializedPlan>,
    },
    /// One input chunk at position `chunk_index`.
    ChunkData {
        stream_id: String,
        chunk_index: usize,
        value: Value,
    },
    /// Marks the end of input for stream `stream_id`.
    StreamEnd { stream_id: String },
    /// Output for the input chunk at position `chunk_index`.
    ChunkResult {
        stream_id: String,
        chunk_index: usize,
        value: Value,
    },
    /// Final result for stream `stream_id`.
    StreamComplete {
        stream_id: String,
        result: PlanResult,
    },
}
399
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.ram_bytes, 32 * 1024 * 1024 * 1024);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.gpus[0].name, "A100");
        assert_eq!(deserialized.gpus[0].memory_bytes, 80 * 1024 * 1024 * 1024);
        assert_eq!(deserialized.python_envs, vec!["py310", "py311"]);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        // Internally tagged: the variant name lives under the `"type"` key.
        assert!(json.contains(r#""type":"Register""#));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        let WorkerToCoordinator::Register { worker_id, .. } = deserialized else {
            panic!("wrong variant");
        };
        assert_eq!(worker_id, "worker_01");
    }

    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                mode: ExecutionMode::default(),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        let CoordinatorToWorker::AssignPlan { plan } = deserialized else {
            panic!("wrong variant");
        };
        assert_eq!(plan.plan_id, "plan_001");
        assert!(plan.filters.is_empty());
        assert!(matches!(plan.mode, ExecutionMode::Forward));
        assert!(matches!(plan.input, Some(InputSource::Inline { .. })));
        assert_eq!(plan.metadata, serde_json::json!({"experiment": "test"}));
    }

    #[test]
    fn input_source_is_tagged_by_source() {
        let input = InputSource::Inline {
            value: Value::tensor(vec![1.0], vec![1]),
        };
        let json = serde_json::to_value(&input).unwrap();
        assert_eq!(json["source"], "Inline");
        let back: InputSource = serde_json::from_value(json).unwrap();
        assert!(matches!(back, InputSource::Inline { .. }));
    }

    #[test]
    fn serialized_filter_base64_roundtrip() {
        let filter = SerializedFilter {
            node_id: "n1".into(),
            pickled_filter: vec![0, 1, 2, 254, 255],
            state: None,
            requirements: vec!["numpy".into()],
            trainable: true,
        };
        let json = serde_json::to_value(&filter).unwrap();
        assert_eq!(json["pickled_filter"], "AAEC/v8=");
        let back: SerializedFilter = serde_json::from_value(json).unwrap();
        assert_eq!(back.pickled_filter, vec![0, 1, 2, 254, 255]);
        assert_eq!(back.requirements, vec!["numpy"]);
        assert!(back.trainable);

        // `requirements` and `trainable` are `#[serde(default)]`, so they may be omitted.
        let minimal: SerializedFilter = serde_json::from_str(
            r#"{"node_id":"n1","pickled_filter":"AAEC/v8=","state":null}"#,
        )
        .unwrap();
        assert_eq!(minimal.pickled_filter, vec![0, 1, 2, 254, 255]);
        assert!(minimal.requirements.is_empty());
        assert!(!minimal.trainable);

        // Invalid base64 must be rejected, not turned into garbage bytes.
        let bad = serde_json::from_str::<SerializedFilter>(
            r#"{"node_id":"n1","pickled_filter":"not base64!","state":null}"#,
        );
        assert!(bad.is_err());
    }

    #[test]
    fn execution_mode_defaults_to_forward() {
        assert!(matches!(ExecutionMode::default(), ExecutionMode::Forward));
    }

    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: OutputDelivery::Inline {
                value: Value::tensor(vec![0.95], vec![1]),
            },
            duration_ms: 1234,
            states: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&success).unwrap();
        assert!(json.contains(r#""status":"Success""#));
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        // `states` is `#[serde(default)]`, so payloads without it still parse.
        let mut without_states = serde_json::to_value(&success).unwrap();
        without_states.as_object_mut().unwrap().remove("states");
        let PlanResult::Success {
            states,
            duration_ms,
            ..
        } = serde_json::from_value::<PlanResult>(without_states).unwrap()
        else {
            panic!("wrong variant");
        };
        assert!(states.is_empty());
        assert_eq!(duration_ms, 1234);

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        let PlanResult::Failed { error, duration_ms } = deserialized else {
            panic!("wrong variant");
        };
        assert_eq!(error, "OOM");
        assert_eq!(duration_ms, 500);
    }

    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        // Fail loudly on the wrong variant instead of silently skipping the assertions.
        let WorkerToCoordinator::Heartbeat { load, .. } = deserialized else {
            panic!("wrong variant");
        };
        assert!(load.cpu_usage > 0.0);
        assert_eq!(load.gpu_usage.len(), 1);
        assert_eq!(load.active_plans, 2);
        assert_eq!(load.queue_depth, 5);
    }
}