Skip to main content

rust_pipe/transport/
mod.rs

1pub mod docker;
2pub mod ssh;
3pub mod stdio;
4pub mod wasm;
5pub mod websocket;
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use crate::schema::{Task, TaskResult};
11
12/// Configuration for the transport layer (WebSocket server settings).
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TransportConfig {
15    pub host: String,
16    pub port: u16,
17    pub max_connections: u32,
18    pub heartbeat_interval_ms: u64,
19    pub reconnect_delay_ms: u64,
20    pub max_reconnect_attempts: u32,
21}
22
23impl Default for TransportConfig {
24    fn default() -> Self {
25        Self {
26            host: "0.0.0.0".to_string(),
27            port: 9876,
28            max_connections: 100,
29            heartbeat_interval_ms: 5_000,
30            reconnect_delay_ms: 1_000,
31            max_reconnect_attempts: 10,
32        }
33    }
34}
35
/// Wire protocol message exchanged between dispatcher and workers.
///
/// Serialized as an internally tagged object. The variant name goes in a
/// `"type"` field next to the variant's own fields, for example
/// `{"type":"HeartbeatAck"}` or `{"type":"Shutdown","graceful":true}`.
/// Nested payloads appear under their field name, e.g.
/// `{"type":"Heartbeat","payload":{...}}`.
///
/// NOTE(review): the enum has no `rename_all`, so the inline fields of these
/// variants stay snake_case on the wire (`worker_id`, `task_id`). The nested
/// payload structs (`HeartbeatPayload`, `WorkerRegistration`,
/// `BackpressureSignal`) use camelCase instead. Peers in other languages must
/// handle both conventions. Changing either one is a breaking protocol change.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Message {
    /// Carries a task to execute (presumably dispatcher → worker).
    TaskDispatch { task: Task },
    /// Carries the outcome of a task (presumably worker → dispatcher).
    TaskResult { result: TaskResult },
    /// Periodic liveness report from a worker; see [`HeartbeatPayload`].
    Heartbeat { payload: HeartbeatPayload },
    /// Acknowledges a heartbeat. Serializes as `{"type":"HeartbeatAck"}`.
    HeartbeatAck,
    /// Sent by a worker when it connects; see [`WorkerRegistration`].
    WorkerRegister { registration: WorkerRegistration },
    /// Confirms a registration and carries the worker id (wire key `worker_id`).
    WorkerRegistered { worker_id: String },
    /// Worker load report; see [`BackpressureSignal`].
    Backpressure { signal: BackpressureSignal },
    /// Asks for the task `task_id` to be terminated, with a human-readable
    /// `reason` (presumably dispatcher → worker — confirm).
    Kill { task_id: uuid::Uuid, reason: String },
    /// Requests shutdown. NOTE(review): `graceful` presumably means "drain
    /// in-flight tasks first" vs. "stop immediately" — confirm with worker SDKs.
    Shutdown { graceful: bool },
}
50
/// Heartbeat data sent periodically by workers to indicate liveness.
///
/// Fields are serialized in camelCase: `workerId`, `activeTasks`,
/// `capacity`, `uptimeSeconds`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HeartbeatPayload {
    /// Identifier of the reporting worker.
    pub worker_id: String,
    /// Number of tasks the worker reports as currently running.
    pub active_tasks: u32,
    /// Task capacity reported by the worker (presumably its maximum number of
    /// concurrent tasks — confirm against the worker SDKs).
    pub capacity: u32,
    /// Worker uptime in seconds, as reported by the worker.
    pub uptime_seconds: u64,
}
60
/// Registration payload sent by a worker when it connects.
///
/// Fields are serialized in camelCase: `workerId`, `supportedTasks`,
/// `maxConcurrency`, `language`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WorkerRegistration {
    /// Identifier the worker registers under.
    pub worker_id: String,
    /// Names of the task types the worker accepts (may be empty).
    /// NOTE(review): presumably matched against `Task::task_type` — confirm
    /// in the dispatcher.
    pub supported_tasks: Vec<String>,
    /// Maximum concurrency the worker declares for itself.
    pub max_concurrency: u32,
    /// Implementation language of the worker. Serialized as a bare string
    /// (e.g. `"Python"`) or as `{"Other":"<name>"}`.
    pub language: WorkerLanguage,
}
70
71/// Programming language of a connected worker.
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub enum WorkerLanguage {
74    TypeScript,
75    Python,
76    Go,
77    Other(String),
78}
79
/// Signal from a worker indicating it is under heavy load.
///
/// Fields are serialized in camelCase: `workerId`, `currentLoad`,
/// `shouldThrottle`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BackpressureSignal {
    /// Identifier of the signalling worker.
    pub worker_id: String,
    /// Load reported by the worker.
    /// NOTE(review): no range is enforced here. The tests use values like
    /// 0.85, suggesting a 0.0–1.0 fraction — confirm with the worker SDKs.
    pub current_load: f64,
    /// Whether the worker asks to be throttled (presumably: fewer new tasks
    /// should be dispatched to it).
    pub should_throttle: bool,
}
88
/// Trait implemented by all transport backends (WebSocket, stdio, Docker, SSH, WASM).
///
/// The methods are async through `#[async_trait]`, so implementations must
/// also be annotated with `#[async_trait]`. The `Send + Sync` supertraits let
/// a backend be shared across threads/tasks (e.g. as `Arc<dyn Transport>`).
#[async_trait]
pub trait Transport: Send + Sync {
    /// Starts the backend.
    async fn start(&self) -> Result<(), TransportError>;
    /// Stops the backend.
    async fn stop(&self) -> Result<(), TransportError>;
    /// Sends `message` to the single worker identified by `worker_id`.
    ///
    /// # Errors
    /// Backend-specific. Presumably `TransportError::WorkerNotFound` for an
    /// unknown id and `SendFailed`/`Closed` on delivery failure — confirm per
    /// backend.
    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError>;
    /// Sends `message` to every worker reachable through this backend.
    async fn broadcast(&self, message: Message) -> Result<(), TransportError>;
}
97
/// Errors that can occur in any transport backend.
///
/// The `Display` text comes from the `#[error]` attributes (via `thiserror`).
/// String payloads are appended to it, e.g. `ConnectionFailed("timeout")`
/// displays as `Connection failed: timeout`.
#[derive(Debug, thiserror::Error)]
pub enum TransportError {
    /// A connection could not be established; the string carries the detail.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),
    /// No worker matched the lookup; the string is presumably the worker id.
    #[error("Worker not found: {0}")]
    WorkerNotFound(String),
    /// A message could not be delivered; the string carries the detail.
    #[error("Send failed: {0}")]
    SendFailed(String),
    /// The transport has been closed.
    #[error("Transport closed")]
    Closed,
}
110
/// Wire-format tests for the transport types.
///
/// Besides round-tripping every message, these pin the exact JSON that peers
/// in other languages depend on:
/// - the `"type"` tag on `Message`;
/// - the snake_case inline variant fields;
/// - the camelCase payload structs;
/// - the `WorkerLanguage` encoding.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::*;
    use serde_json::json;
    use uuid::Uuid;

    #[test]
    fn test_transport_config_defaults() {
        let config = TransportConfig::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 9876);
        assert_eq!(config.max_connections, 100);
        assert_eq!(config.heartbeat_interval_ms, 5_000);
        assert_eq!(config.reconnect_delay_ms, 1_000);
        assert_eq!(config.max_reconnect_attempts, 10);
    }

    #[test]
    fn test_transport_config_serde_roundtrip() {
        let config = TransportConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        // The config has no `rename_all`, so keys are the field names.
        assert!(json.contains("heartbeat_interval_ms"));
        let de: TransportConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(de.host, config.host);
        assert_eq!(de.port, config.port);
        assert_eq!(de.max_connections, config.max_connections);
        assert_eq!(de.heartbeat_interval_ms, config.heartbeat_interval_ms);
        assert_eq!(de.reconnect_delay_ms, config.reconnect_delay_ms);
        assert_eq!(de.max_reconnect_attempts, config.max_reconnect_attempts);
    }

    #[test]
    fn test_message_serde_task_dispatch() {
        let task = Task::new("scan", json!({"url": "http://x.com"}));
        let msg = Message::TaskDispatch { task };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"TaskDispatch""#));
        let val: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(val.get("task").is_some());
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::TaskDispatch { task: t } = de {
            assert_eq!(t.task_type, "scan");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_task_result() {
        let id = Uuid::new_v4();
        let result = TaskResult {
            task_id: id,
            status: TaskStatus::Completed,
            payload: Some(json!({"v": 1})),
            error: None,
            duration_ms: 100,
            worker_id: "w1".into(),
        };
        let msg = Message::TaskResult { result };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"TaskResult""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::TaskResult { result: r } = de {
            assert_eq!(r.task_id, id);
            assert_eq!(r.payload, Some(json!({"v": 1})));
            assert!(r.error.is_none());
            assert_eq!(r.duration_ms, 100);
            assert_eq!(r.worker_id, "w1");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_heartbeat() {
        let msg = Message::Heartbeat {
            payload: HeartbeatPayload {
                worker_id: "w1".into(),
                active_tasks: 3,
                capacity: 10,
                uptime_seconds: 120,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("activeTasks"));
        assert!(json.contains("uptimeSeconds"));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::Heartbeat { payload: p } = de {
            assert_eq!(p.worker_id, "w1");
            assert_eq!(p.active_tasks, 3);
            assert_eq!(p.capacity, 10);
            assert_eq!(p.uptime_seconds, 120);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_heartbeat_ack() {
        let msg = Message::HeartbeatAck;
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"type":"HeartbeatAck"}"#);
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::HeartbeatAck));
    }

    #[test]
    fn test_message_serde_worker_register() {
        let msg = Message::WorkerRegister {
            registration: WorkerRegistration {
                worker_id: "w1".into(),
                supported_tasks: vec!["a".into(), "b".into()],
                max_concurrency: 5,
                language: WorkerLanguage::Python,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("supportedTasks"));
        assert!(json.contains("maxConcurrency"));
        // The registration is nested under its field name; inside it, keys are camelCase.
        let val: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(val["registration"]["workerId"], "w1");
        assert_eq!(val["registration"]["language"], "Python");
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::WorkerRegister { registration: reg } = de {
            assert_eq!(reg.worker_id, "w1");
            assert_eq!(reg.supported_tasks.len(), 2);
            assert_eq!(reg.max_concurrency, 5);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_worker_registered() {
        let msg = Message::WorkerRegistered {
            worker_id: "w1".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"WorkerRegistered""#));
        // Inline variant fields are not renamed, so this key stays snake_case.
        assert!(json.contains(r#""worker_id":"w1""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::WorkerRegistered { worker_id } = de {
            assert_eq!(worker_id, "w1");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_backpressure() {
        let msg = Message::Backpressure {
            signal: BackpressureSignal {
                worker_id: "w1".into(),
                current_load: 0.85,
                should_throttle: true,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("currentLoad"));
        assert!(json.contains("shouldThrottle"));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::Backpressure { signal } = de {
            assert_eq!(signal.worker_id, "w1");
            assert!((signal.current_load - 0.85).abs() < f64::EPSILON);
            assert!(signal.should_throttle);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_kill() {
        let id = Uuid::new_v4();
        let msg = Message::Kill {
            task_id: id,
            reason: "timeout".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"Kill""#));
        assert!(json.contains(r#""task_id""#));
        assert!(json.contains(r#""reason":"timeout""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::Kill { task_id, reason } = de {
            assert_eq!(task_id, id);
            assert_eq!(reason, "timeout");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_shutdown_graceful() {
        let msg = Message::Shutdown { graceful: true };
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"type":"Shutdown","graceful":true}"#);
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Shutdown { graceful: true }));
    }

    #[test]
    fn test_message_serde_shutdown_not_graceful() {
        let msg = Message::Shutdown { graceful: false };
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"type":"Shutdown","graceful":false}"#);
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Shutdown { graceful: false }));
    }

    #[test]
    fn test_message_internally_tagged() {
        let msg = Message::HeartbeatAck;
        let json = serde_json::to_string(&msg).unwrap();
        let val: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(val["type"], "HeartbeatAck");
    }

    #[test]
    fn test_message_deserialization_rejects_unknown_type() {
        let bad = r#"{"type":"UnknownMessage","data":{}}"#;
        let result = serde_json::from_str::<Message>(bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_message_deserialization_rejects_missing_type() {
        let bad = r#"{"graceful":true}"#;
        assert!(serde_json::from_str::<Message>(bad).is_err());
    }

    #[test]
    fn test_worker_language_serde_all_variants() {
        for (lang, wire) in [
            (WorkerLanguage::TypeScript, r#""TypeScript""#),
            (WorkerLanguage::Python, r#""Python""#),
            (WorkerLanguage::Go, r#""Go""#),
        ] {
            let json = serde_json::to_string(&lang).unwrap();
            assert_eq!(json, wire);
            let de: WorkerLanguage = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{:?}", de), format!("{:?}", lang));
        }
    }

    #[test]
    fn test_worker_language_serde_other() {
        let lang = WorkerLanguage::Other("ruby".into());
        let json = serde_json::to_string(&lang).unwrap();
        assert_eq!(json, r#"{"Other":"ruby"}"#);
        let de: WorkerLanguage = serde_json::from_str(&json).unwrap();
        if let WorkerLanguage::Other(s) = de {
            assert_eq!(s, "ruby");
        } else {
            panic!("Expected Other variant");
        }
    }

    #[test]
    fn test_transport_error_display() {
        assert_eq!(
            TransportError::ConnectionFailed("timeout".into()).to_string(),
            "Connection failed: timeout"
        );
        assert_eq!(
            TransportError::WorkerNotFound("w1".into()).to_string(),
            "Worker not found: w1"
        );
        assert_eq!(
            TransportError::SendFailed("broken".into()).to_string(),
            "Send failed: broken"
        );
        assert_eq!(TransportError::Closed.to_string(), "Transport closed");
    }

    #[test]
    fn test_heartbeat_payload_camel_case() {
        let p = HeartbeatPayload {
            worker_id: "w".into(),
            active_tasks: 1,
            capacity: 5,
            uptime_seconds: 60,
        };
        let json = serde_json::to_string(&p).unwrap();
        assert!(json.contains("workerId"));
        assert!(json.contains("activeTasks"));
        assert!(json.contains("capacity"));
        assert!(json.contains("uptimeSeconds"));
        assert!(!json.contains("worker_id"));
        assert!(!json.contains("active_tasks"));
    }

    #[test]
    fn test_worker_registration_serde_roundtrip() {
        let reg = WorkerRegistration {
            worker_id: "w1".into(),
            supported_tasks: vec!["a".into(), "b".into()],
            max_concurrency: 10,
            language: WorkerLanguage::Go,
        };
        let json = serde_json::to_string(&reg).unwrap();
        let de: WorkerRegistration = serde_json::from_str(&json).unwrap();
        assert_eq!(de.worker_id, "w1");
        assert_eq!(de.supported_tasks.len(), 2);
        assert_eq!(de.max_concurrency, 10);
        assert_eq!(format!("{:?}", de.language), "Go");
    }

    #[test]
    fn test_worker_registration_empty_tasks() {
        let reg = WorkerRegistration {
            worker_id: "w".into(),
            supported_tasks: vec![],
            max_concurrency: 1,
            language: WorkerLanguage::TypeScript,
        };
        let json = serde_json::to_string(&reg).unwrap();
        assert!(json.contains(r#""supportedTasks":[]"#));
        let de: WorkerRegistration = serde_json::from_str(&json).unwrap();
        assert!(de.supported_tasks.is_empty());
    }

    #[test]
    fn test_backpressure_signal_serde() {
        let sig = BackpressureSignal {
            worker_id: "w1".into(),
            current_load: 0.95,
            should_throttle: true,
        };
        let json = serde_json::to_string(&sig).unwrap();
        assert!(json.contains("currentLoad"));
        assert!(json.contains("shouldThrottle"));
        let de: BackpressureSignal = serde_json::from_str(&json).unwrap();
        assert!((de.current_load - 0.95).abs() < f64::EPSILON);
        assert!(de.should_throttle);
    }
}
385        assert!(json.contains("shouldThrottle"));
386        let de: BackpressureSignal = serde_json::from_str(&json).unwrap();
387        assert!((de.current_load - 0.95).abs() < f64::EPSILON);
388    }
389}