Skip to main content

rust_pipe/transport/
mod.rs

1pub mod docker;
2pub mod ssh;
3pub mod stdio;
4pub mod wasm;
5pub mod websocket;
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use crate::schema::{Task, TaskResult};
11
12/// Configuration for the transport layer (WebSocket server settings).
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TransportConfig {
15    pub host: String,
16    pub port: u16,
17    pub max_connections: u32,
18    pub heartbeat_interval_ms: u64,
19    pub reconnect_delay_ms: u64,
20    pub max_reconnect_attempts: u32,
21}
22
23impl Default for TransportConfig {
24    fn default() -> Self {
25        Self {
26            host: "0.0.0.0".to_string(),
27            port: 9876,
28            max_connections: 100,
29            heartbeat_interval_ms: 5_000,
30            reconnect_delay_ms: 1_000,
31            max_reconnect_attempts: 10,
32        }
33    }
34}
35
36/// Wire protocol message exchanged between dispatcher and workers.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38#[serde(tag = "type")]
39pub enum Message {
40    TaskDispatch { task: Task },
41    TaskResult { result: TaskResult },
42    Heartbeat { payload: HeartbeatPayload },
43    HeartbeatAck,
44    WorkerRegister { registration: WorkerRegistration },
45    WorkerRegistered { worker_id: String },
46    Backpressure { signal: BackpressureSignal },
47    Kill { task_id: uuid::Uuid, reason: String },
48    Shutdown { graceful: bool },
49}
50
51/// Heartbeat data sent periodically by workers to indicate liveness.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(rename_all = "camelCase")]
54pub struct HeartbeatPayload {
55    pub worker_id: String,
56    pub active_tasks: u32,
57    pub capacity: u32,
58    pub uptime_seconds: u64,
59}
60
61/// Registration payload sent by a worker when it connects.
62#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(rename_all = "camelCase")]
64pub struct WorkerRegistration {
65    pub worker_id: String,
66    pub supported_tasks: Vec<String>,
67    pub max_concurrency: u32,
68    pub language: WorkerLanguage,
69    /// Optional tags for targeted dispatch routing.
70    #[serde(default)]
71    pub tags: Option<Vec<String>>,
72}
73
74/// Programming language of a connected worker.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub enum WorkerLanguage {
77    TypeScript,
78    Python,
79    Go,
80    Other(String),
81}
82
83/// Signal from a worker indicating it is under heavy load.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85#[serde(rename_all = "camelCase")]
86pub struct BackpressureSignal {
87    pub worker_id: String,
88    pub current_load: f64,
89    pub should_throttle: bool,
90}
91
92/// Trait implemented by all transport backends (WebSocket, stdio, Docker, SSH, WASM).
93#[async_trait]
94pub trait Transport: Send + Sync {
95    async fn start(&self) -> Result<(), TransportError>;
96    async fn stop(&self) -> Result<(), TransportError>;
97    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError>;
98    async fn broadcast(&self, message: Message) -> Result<(), TransportError>;
99}
100
101/// Errors that can occur in any transport backend.
102#[derive(Debug, thiserror::Error)]
103pub enum TransportError {
104    #[error("Connection failed: {0}")]
105    ConnectionFailed(String),
106    #[error("Worker not found: {0}")]
107    WorkerNotFound(String),
108    #[error("Send failed: {0}")]
109    SendFailed(String),
110    #[error("Transport closed")]
111    Closed,
112}
113
// Serde round-trip and wire-format tests for the transport protocol types.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::*;
    use serde_json::json;
    use uuid::Uuid;

    #[test]
    fn test_transport_config_defaults() {
        let config = TransportConfig::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 9876);
        assert_eq!(config.max_connections, 100);
        assert_eq!(config.heartbeat_interval_ms, 5_000);
        assert_eq!(config.reconnect_delay_ms, 1_000);
        assert_eq!(config.max_reconnect_attempts, 10);
    }

    #[test]
    fn test_message_serde_task_dispatch() {
        let task = Task::new("scan", json!({"url": "http://x.com"}));
        let msg = Message::TaskDispatch { task };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"TaskDispatch""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::TaskDispatch { task: t } = de {
            assert_eq!(t.task_type, "scan");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_task_result() {
        let result = TaskResult {
            task_id: Uuid::new_v4(),
            status: TaskStatus::Completed,
            payload: Some(json!({"v": 1})),
            error: None,
            duration_ms: 100,
            worker_id: "w1".into(),
        };
        let msg = Message::TaskResult { result };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"TaskResult""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::TaskResult { .. }));
    }

    #[test]
    fn test_message_serde_heartbeat() {
        let msg = Message::Heartbeat {
            payload: HeartbeatPayload {
                worker_id: "w1".into(),
                active_tasks: 3,
                capacity: 10,
                uptime_seconds: 120,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("activeTasks"));
        assert!(json.contains("uptimeSeconds"));
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Heartbeat { .. }));
    }

    #[test]
    fn test_message_serde_heartbeat_ack() {
        let msg = Message::HeartbeatAck;
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"type":"HeartbeatAck"}"#);
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::HeartbeatAck));
    }

    #[test]
    fn test_message_serde_worker_register() {
        let msg = Message::WorkerRegister {
            registration: WorkerRegistration {
                worker_id: "w1".into(),
                supported_tasks: vec!["a".into(), "b".into()],
                max_concurrency: 5,
                language: WorkerLanguage::Python,
                tags: None,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("supportedTasks"));
        assert!(json.contains("maxConcurrency"));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::WorkerRegister { registration: reg } = de {
            assert_eq!(reg.worker_id, "w1");
            assert_eq!(reg.supported_tasks.len(), 2);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_worker_registered() {
        let msg = Message::WorkerRegistered {
            worker_id: "w1".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::WorkerRegistered { worker_id } = de {
            assert_eq!(worker_id, "w1");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_backpressure() {
        let msg = Message::Backpressure {
            signal: BackpressureSignal {
                worker_id: "w1".into(),
                current_load: 0.85,
                should_throttle: true,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("currentLoad"));
        assert!(json.contains("shouldThrottle"));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::Backpressure { signal } = de {
            assert!((signal.current_load - 0.85).abs() < f64::EPSILON);
            assert!(signal.should_throttle);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_kill() {
        let id = Uuid::new_v4();
        let msg = Message::Kill {
            task_id: id,
            reason: "timeout".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::Kill { task_id, reason } = de {
            assert_eq!(task_id, id);
            assert_eq!(reason, "timeout");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_shutdown_graceful() {
        let msg = Message::Shutdown { graceful: true };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Shutdown { graceful: true }));
    }

    #[test]
    fn test_message_serde_shutdown_not_graceful() {
        let msg = Message::Shutdown { graceful: false };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Shutdown { graceful: false }));
    }

    #[test]
    fn test_message_internally_tagged() {
        // The tag must carry the variant *name*, not merely exist. A unit
        // variant serializes to the tag alone.
        let msg = Message::HeartbeatAck;
        let json = serde_json::to_string(&msg).unwrap();
        let val: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(val, json!({"type": "HeartbeatAck"}));

        // A struct variant places the tag alongside its own fields (no wrapper object).
        let shutdown = serde_json::to_value(Message::Shutdown { graceful: true }).unwrap();
        assert_eq!(shutdown, json!({"type": "Shutdown", "graceful": true}));
    }

    #[test]
    fn test_message_deserialization_rejects_unknown_type() {
        let bad = r#"{"type":"UnknownMessage","data":{}}"#;
        let result = serde_json::from_str::<Message>(bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_worker_language_serde_all_variants() {
        for lang in [
            WorkerLanguage::TypeScript,
            WorkerLanguage::Python,
            WorkerLanguage::Go,
        ] {
            let json = serde_json::to_string(&lang).unwrap();
            let de: WorkerLanguage = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{:?}", de), format!("{:?}", lang));
        }
    }

    #[test]
    fn test_worker_language_serde_other() {
        let lang = WorkerLanguage::Other("ruby".into());
        let json = serde_json::to_string(&lang).unwrap();
        let de: WorkerLanguage = serde_json::from_str(&json).unwrap();
        if let WorkerLanguage::Other(s) = de {
            assert_eq!(s, "ruby");
        } else {
            panic!("Expected Other variant");
        }
    }

    #[test]
    fn test_worker_language_wire_format() {
        // Pin the externally tagged representation that workers must produce.
        assert_eq!(
            serde_json::to_string(&WorkerLanguage::Python).unwrap(),
            r#""Python""#
        );
        assert_eq!(
            serde_json::to_string(&WorkerLanguage::Other("ruby".into())).unwrap(),
            r#"{"Other":"ruby"}"#
        );
    }

    #[test]
    fn test_transport_error_display() {
        assert_eq!(
            TransportError::ConnectionFailed("timeout".into()).to_string(),
            "Connection failed: timeout"
        );
        assert_eq!(
            TransportError::WorkerNotFound("w1".into()).to_string(),
            "Worker not found: w1"
        );
        assert_eq!(
            TransportError::SendFailed("broken".into()).to_string(),
            "Send failed: broken"
        );
        assert_eq!(TransportError::Closed.to_string(), "Transport closed");
    }

    #[test]
    fn test_heartbeat_payload_camel_case() {
        let p = HeartbeatPayload {
            worker_id: "w".into(),
            active_tasks: 1,
            capacity: 5,
            uptime_seconds: 60,
        };
        let json = serde_json::to_string(&p).unwrap();
        assert!(json.contains("workerId"));
        assert!(json.contains("activeTasks"));
        assert!(json.contains("uptimeSeconds"));
        assert!(!json.contains("worker_id"));
    }

    #[test]
    fn test_worker_registration_serde_roundtrip() {
        let reg = WorkerRegistration {
            worker_id: "w1".into(),
            supported_tasks: vec!["a".into(), "b".into()],
            max_concurrency: 10,
            language: WorkerLanguage::Go,
            tags: Some(vec!["gpu".into()]),
        };
        let json = serde_json::to_string(&reg).unwrap();
        let de: WorkerRegistration = serde_json::from_str(&json).unwrap();
        assert_eq!(de.worker_id, "w1");
        assert_eq!(de.supported_tasks.len(), 2);
        assert_eq!(de.max_concurrency, 10);
        assert_eq!(de.tags, Some(vec!["gpu".to_string()]));
    }

    #[test]
    fn test_worker_registration_empty_tasks() {
        let reg = WorkerRegistration {
            worker_id: "w".into(),
            supported_tasks: vec![],
            max_concurrency: 1,
            language: WorkerLanguage::TypeScript,
            tags: None,
        };
        let json = serde_json::to_string(&reg).unwrap();
        let de: WorkerRegistration = serde_json::from_str(&json).unwrap();
        assert!(de.supported_tasks.is_empty());
    }

    #[test]
    fn test_worker_registration_missing_tags_defaults_to_none() {
        // Workers may omit the optional `tags` key entirely.
        let raw = r#"{"workerId":"w","supportedTasks":["a"],"maxConcurrency":2,"language":"Go"}"#;
        let reg: WorkerRegistration = serde_json::from_str(raw).unwrap();
        assert_eq!(reg.worker_id, "w");
        assert_eq!(reg.supported_tasks, vec!["a".to_string()]);
        assert_eq!(reg.max_concurrency, 2);
        assert_eq!(reg.tags, None);
    }

    #[test]
    fn test_backpressure_signal_serde() {
        let sig = BackpressureSignal {
            worker_id: "w1".into(),
            current_load: 0.95,
            should_throttle: true,
        };
        let json = serde_json::to_string(&sig).unwrap();
        assert!(json.contains("currentLoad"));
        assert!(json.contains("shouldThrottle"));
        let de: BackpressureSignal = serde_json::from_str(&json).unwrap();
        assert!((de.current_load - 0.95).abs() < f64::EPSILON);
    }
}