Skip to main content

somatize_worker/
protocol.rs

//! Wire protocol for coordinator ↔ worker communication.
//!
//! Defines message types for plan assignment, results, heartbeats,
//! Python job management, worker capabilities, distributed-training
//! state/gradient exchange, and chunked streaming.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use somatize_compiler::ExecutionPlan;
9use somatize_core::event::Event;
10use somatize_core::store::{DataRef, DataStore};
11use somatize_core::value::Value;
12
/// Identifier of a worker; present in every worker → coordinator message.
pub type WorkerId = String;

/// Identifier of a single plan execution.
pub type PlanId = String;
18
19/// Hardware and software capabilities of a worker.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Capabilities {
22    /// Number of CPU cores.
23    pub cpu_cores: usize,
24    /// Total RAM in bytes.
25    pub ram_bytes: u64,
26    /// GPU information.
27    pub gpus: Vec<GpuInfo>,
28    /// Available Python environments.
29    pub python_envs: Vec<String>,
30    /// User-defined tags for routing (e.g. "gpu", "training", "inference").
31    pub tags: Vec<String>,
32}
33
34/// GPU hardware info.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct GpuInfo {
37    pub name: String,
38    pub memory_bytes: u64,
39}
40
41/// Current load metrics reported by a worker.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct LoadMetrics {
44    pub cpu_usage: f32,
45    pub memory_usage: f32,
46    pub gpu_usage: Vec<f32>,
47    pub active_plans: usize,
48    pub queue_depth: usize,
49    pub timestamp: DateTime<Utc>,
50}
51
52/// How input data is provided to a worker.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54#[serde(tag = "source")]
55#[non_exhaustive]
56pub enum InputSource {
57    /// Data embedded directly in the message (small payloads).
58    Inline { value: Value },
59    /// Data referenced in a remote store (large payloads).
60    Reference { data_ref: DataRef },
61}
62
63impl InputSource {
64    /// Resolve the input to a concrete Value.
65    /// Tries persistent DataStore first, then temp store for HTTP uploads.
66    pub fn resolve(
67        &self,
68        data_store: Option<&dyn somatize_core::store::DataStore>,
69        temp_store: &somatize_core::store::LocalDataStore,
70    ) -> Value {
71        match self {
72            InputSource::Inline { value } => value.clone(),
73            InputSource::Reference { data_ref } => {
74                if let Some(store) = data_store
75                    && let Ok(val) = store.get(data_ref)
76                {
77                    return val;
78                }
79                temp_store.get(data_ref).unwrap_or_else(|e| {
80                    tracing::warn!("Failed to resolve DataRef: {e}");
81                    Value::Empty
82                })
83            }
84        }
85    }
86}
87
88/// A serialized filter: cloudpickle bytes to reconstruct on the worker.
89///
90/// Uses cloudpickle (like Spark/Dask/Ray) to serialize the full Python object
91/// including bytecode, closures, and cross-module dependencies.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct SerializedFilter {
94    /// Node ID this filter is registered under.
95    pub node_id: String,
96    /// cloudpickle.dumps() bytes (base64-encoded for JSON transport).
97    #[serde(with = "base64_bytes")]
98    pub pickled_filter: Vec<u8>,
99    /// Trained state (if fitted).
100    pub state: Option<Value>,
101    /// Pip requirements detected from the filter's imports (e.g. ["torch", "transformers"]).
102    #[serde(default)]
103    pub requirements: Vec<String>,
104    /// Whether the filter is trainable (has meaningful fit()) or stateless.
105    #[serde(default)]
106    pub trainable: bool,
107}
108
109/// Serde helper: Vec<u8> ↔ base64 string for JSON-safe binary transport.
110mod base64_bytes {
111    use base64::engine::{Engine, general_purpose::STANDARD};
112    use serde::{Deserialize, Deserializer, Serialize, Serializer};
113
114    pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
115        STANDARD.encode(bytes).serialize(s)
116    }
117
118    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
119        let s = String::deserialize(d)?;
120        STANDARD.decode(s).map_err(serde::de::Error::custom)
121    }
122}
123
124/// Execution mode: fit (training) or forward (inference).
125#[derive(Debug, Clone, Serialize, Deserialize, Default)]
126#[non_exhaustive]
127pub enum ExecutionMode {
128    /// Training: fit each filter, then forward to propagate outputs.
129    Fit {
130        /// Supervised labels (optional).
131        y: Option<Value>,
132    },
133    /// Inference: forward only (default).
134    #[default]
135    Forward,
136}
137
138/// A serialized plan ready for remote execution.
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct SerializedPlan {
141    pub plan_id: PlanId,
142    pub plan: ExecutionPlan,
143    /// Input data — inline for small values, DataRef for large ones.
144    pub input: Option<InputSource>,
145    /// Filter definitions for the worker to reconstruct.
146    #[serde(default)]
147    pub filters: Vec<SerializedFilter>,
148    /// Fit or Forward.
149    #[serde(default)]
150    pub mode: ExecutionMode,
151    pub metadata: serde_json::Value,
152}
153
154/// Messages from Worker → Coordinator.
155#[derive(Debug, Clone, Serialize, Deserialize)]
156#[serde(tag = "type")]
157pub enum WorkerToCoordinator {
158    /// Worker announces itself.
159    Register {
160        worker_id: WorkerId,
161        capabilities: Capabilities,
162    },
163
164    /// Periodic health check.
165    Heartbeat {
166        worker_id: WorkerId,
167        load: LoadMetrics,
168    },
169
170    /// Execution event streamed back in real-time.
171    Event {
172        worker_id: WorkerId,
173        plan_id: PlanId,
174        event: Event,
175    },
176
177    /// Plan execution completed.
178    PlanResult {
179        worker_id: WorkerId,
180        plan_id: PlanId,
181        result: PlanResult,
182    },
183
184    /// Python job progress update.
185    JobProgress {
186        worker_id: WorkerId,
187        job_id: String,
188        phase: String,
189        step: u32,
190        total: u32,
191        metrics: serde_json::Value,
192    },
193
194    /// Python job result.
195    JobResult {
196        worker_id: WorkerId,
197        job_id: String,
198        success: bool,
199        metrics: serde_json::Value,
200        output: String,
201        duration_ms: u64,
202    },
203
204    // ── Distributed training responses ──
205    /// Response to GetState: trained filter states.
206    StateResult {
207        worker_id: WorkerId,
208        plan_id: PlanId,
209        states: std::collections::HashMap<String, Value>,
210    },
211
212    /// Response to GetGradients: gradient data.
213    GradientsResult {
214        worker_id: WorkerId,
215        plan_id: PlanId,
216        gradients: std::collections::HashMap<String, Value>,
217    },
218}
219
220/// A Python pipeline job: source files + requirements for isolated execution.
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct PythonPipelineJob {
223    pub job_id: String,
224    pub pipeline_id: String,
225    pub investigation_id: String,
226    /// Source files: path → content
227    pub files: Vec<PipelineFile>,
228    /// pip requirements (content of requirements.txt)
229    pub requirements: String,
230    /// Entry point: which file/function to execute
231    pub entry_point: String,
232    /// Input data (JSON-serialized)
233    pub input_data: Option<serde_json::Value>,
234    /// Extra parameters
235    pub params: serde_json::Value,
236}
237
238/// A source file in a pipeline job.
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct PipelineFile {
241    pub path: String,
242    pub content: String,
243}
244
245/// Messages from Coordinator → Worker.
246#[derive(Debug, Clone, Serialize, Deserialize)]
247#[serde(tag = "type")]
248pub enum CoordinatorToWorker {
249    /// Accept worker registration.
250    Registered { worker_id: WorkerId },
251
252    /// Assign a native Soma plan for execution.
253    AssignPlan { plan: SerializedPlan },
254
255    /// Assign a Python pipeline job (with environment isolation).
256    AssignPythonJob { job: PythonPipelineJob },
257
258    /// Cancel a running plan/job.
259    CancelPlan { plan_id: PlanId },
260
261    /// Request current status.
262    StatusRequest,
263
264    /// Ping for keepalive.
265    Ping,
266
267    /// Graceful shutdown: worker should finish running plans and exit.
268    Shutdown { reason: String },
269
270    // ── Distributed training messages ──
271    /// Request trained states from specific filters.
272    GetState {
273        plan_id: PlanId,
274        node_ids: Vec<String>,
275    },
276
277    /// Load states into filters (e.g. after FedAvg aggregation).
278    SetState {
279        plan_id: PlanId,
280        states: std::collections::HashMap<String, Value>,
281    },
282
283    /// Request gradients from filters (for AllReduce in DataParallel).
284    GetGradients {
285        plan_id: PlanId,
286        node_ids: Vec<String>,
287    },
288
289    /// Apply aggregated gradients (after AllReduce).
290    ApplyGradients {
291        plan_id: PlanId,
292        gradients: std::collections::HashMap<String, Value>,
293    },
294}
295
296/// How output is delivered in PlanResult.
297#[derive(Debug, Clone, Serialize, Deserialize)]
298#[serde(tag = "delivery")]
299#[non_exhaustive]
300pub enum OutputDelivery {
301    /// Small output — embedded directly in the WS message.
302    Inline { value: Value },
303    /// Large output — stored on worker, download via HTTP GET /download?key=...
304    Reference {
305        data_ref: somatize_core::store::DataRef,
306    },
307}
308
309impl OutputDelivery {
310    /// Resolve the output to a concrete Value.
311    /// For Reference: downloads via HTTP from the worker.
312    pub fn resolve(&self, addr: &str, token: &Option<String>) -> Value {
313        match self {
314            OutputDelivery::Inline { value } => value.clone(),
315            OutputDelivery::Reference { data_ref } => {
316                // HTTP download in a dedicated thread (avoids tokio nesting)
317                let http_addr = addr
318                    .replace("ws://", "http://")
319                    .replace("wss://", "https://");
320                let url = format!("{http_addr}/download");
321                let ref_json = serde_json::to_string(data_ref).unwrap_or_default();
322                let token = token.clone();
323
324                std::thread::spawn(move || {
325                    let client = reqwest::blocking::Client::new();
326                    let mut req = client.get(&url).query(&[("ref", &ref_json)]);
327                    if let Some(t) = &token {
328                        req = req.query(&[("token", t.as_str())]);
329                    }
330                    let resp = req.send().ok()?;
331                    let bytes = resp.bytes().ok()?;
332                    serde_json::from_slice(&bytes).ok()
333                })
334                .join()
335                .ok()
336                .flatten()
337                .unwrap_or(Value::Empty)
338            }
339        }
340    }
341}
342
343/// Result of a plan execution.
344#[derive(Debug, Clone, Serialize, Deserialize)]
345#[serde(tag = "status")]
346pub enum PlanResult {
347    Success {
348        output: OutputDelivery,
349        duration_ms: u64,
350        /// Trained states returned after Fit mode (node_id → state).
351        /// Empty for Forward mode.
352        #[serde(default)]
353        states: std::collections::HashMap<String, Value>,
354    },
355    Failed {
356        error: String,
357        duration_ms: u64,
358    },
359}
360
361/// Streaming protocol: chunked data transfer over WebSocket Binary frames.
362///
363/// Wire format: msgpack-encoded StreamMessage (efficient binary, no JSON overhead).
364/// Client sends StreamBegin + N × ChunkData + StreamEnd.
365/// Worker responds with ChunkResult per chunk + StreamComplete at the end.
366#[derive(Debug, Clone, Serialize, Deserialize)]
367#[serde(tag = "type")]
368#[non_exhaustive]
369pub enum StreamMessage {
370    /// Begin a streaming session.
371    StreamBegin {
372        stream_id: String,
373        plan_id: PlanId,
374        /// Number of chunks (None if unknown ahead of time).
375        total_chunks: Option<usize>,
376        /// The plan to execute — input comes via chunks, not inline.
377        plan: Box<SerializedPlan>,
378    },
379    /// A single chunk of input data.
380    ChunkData {
381        stream_id: String,
382        chunk_index: usize,
383        value: Value,
384    },
385    /// All chunks have been sent.
386    StreamEnd { stream_id: String },
387    /// Result for a processed chunk (streamed back to client).
388    ChunkResult {
389        stream_id: String,
390        chunk_index: usize,
391        value: Value,
392    },
393    /// Final result after all chunks processed.
394    StreamComplete {
395        stream_id: String,
396        result: PlanResult,
397    },
398}
399
/// JSON round-trip tests for the protocol types.
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    /// `Capabilities` survives a JSON round trip with its fields intact.
    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    /// `Register` is tagged with its variant name and round-trips.
    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Register"));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        let WorkerToCoordinator::Register { worker_id, .. } = deserialized else {
            panic!("wrong variant");
        };
        assert_eq!(worker_id, "worker_01");
    }

    /// `AssignPlan` with an inline input round-trips.
    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                mode: ExecutionMode::default(),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            CoordinatorToWorker::AssignPlan { .. }
        ));
    }

    /// Both `PlanResult` variants round-trip to the same variant.
    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: OutputDelivery::Inline {
                value: Value::tensor(vec![0.95], vec![1]),
            },
            duration_ms: 1234,
            states: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Success { .. }));

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, PlanResult::Failed { .. }));
    }

    /// An `Event` message wrapping `RunStarted` round-trips.
    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, WorkerToCoordinator::Event { .. }));
    }

    /// `Heartbeat` round-trips and keeps its load metrics.
    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        // Fail loudly on the wrong variant instead of silently skipping the
        // assertions below.
        let WorkerToCoordinator::Heartbeat { load, .. } = deserialized else {
            panic!("wrong variant");
        };
        assert!(load.cpu_usage > 0.0);
        assert_eq!(load.active_plans, 2);
        assert_eq!(load.queue_depth, 5);
    }
}