1pub mod docker;
2pub mod ssh;
3pub mod stdio;
4pub mod wasm;
5pub mod websocket;
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use crate::schema::{Task, TaskResult};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TransportConfig {
15 pub host: String,
16 pub port: u16,
17 pub max_connections: u32,
18 pub heartbeat_interval_ms: u64,
19 pub reconnect_delay_ms: u64,
20 pub max_reconnect_attempts: u32,
21}
22
23impl Default for TransportConfig {
24 fn default() -> Self {
25 Self {
26 host: "0.0.0.0".to_string(),
27 port: 9876,
28 max_connections: 100,
29 heartbeat_interval_ms: 5_000,
30 reconnect_delay_ms: 1_000,
31 max_reconnect_attempts: 10,
32 }
33 }
34}
35
/// Envelope for every message passed through a [`Transport`] (`send` / `broadcast`).
///
/// Serialized as an internally tagged object. The variant name is stored in a
/// `"type"` field next to the variant's own fields, for example
/// `{"type":"HeartbeatAck"}` or `{"type":"Shutdown","graceful":true}`.
///
/// No `rename_all` is applied to this enum, so the fields declared directly on
/// variants stay snake_case on the wire (`task_id`, `worker_id`). The nested
/// payload structs (`HeartbeatPayload`, `WorkerRegistration`,
/// `BackpressureSignal`) serialize their own fields in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Message {
    /// Carries a [`Task`]; presumably dispatched to a worker for execution.
    TaskDispatch { task: Task },
    /// Carries a [`TaskResult`]; presumably reported back by a worker.
    TaskResult { result: TaskResult },
    /// Worker status report; see [`HeartbeatPayload`].
    Heartbeat { payload: HeartbeatPayload },
    /// Heartbeat acknowledgement; has no fields and serializes as `{"type":"HeartbeatAck"}`.
    HeartbeatAck,
    /// A worker announcing itself; see [`WorkerRegistration`].
    WorkerRegister { registration: WorkerRegistration },
    /// Registration confirmation carrying the worker's id.
    WorkerRegistered { worker_id: String },
    /// Load report from a worker; see [`BackpressureSignal`].
    Backpressure { signal: BackpressureSignal },
    /// Request to kill the task identified by `task_id`, with a textual `reason`.
    Kill { task_id: uuid::Uuid, reason: String },
    /// Shutdown request. How `graceful` is honoured is up to the receiver.
    Shutdown { graceful: bool },
}
50
/// Worker status snapshot carried by `Message::Heartbeat`.
///
/// Fields serialize in camelCase (`workerId`, `activeTasks`, `capacity`,
/// `uptimeSeconds`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HeartbeatPayload {
    /// Id of the reporting worker.
    pub worker_id: String,
    /// Number of tasks the worker reports as active.
    pub active_tasks: u32,
    /// Capacity reported by the worker. NOTE(review): confirm whether this is
    /// total concurrency or remaining slots.
    pub capacity: u32,
    /// Worker uptime, in seconds per the field name.
    pub uptime_seconds: u64,
}
60
/// Registration data carried by `Message::WorkerRegister`.
///
/// Fields serialize in camelCase (`workerId`, `supportedTasks`,
/// `maxConcurrency`, `language`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WorkerRegistration {
    /// Id the worker registers under.
    pub worker_id: String,
    /// Task type names the worker declares it can handle. May be empty.
    pub supported_tasks: Vec<String>,
    /// Maximum concurrency the worker declares.
    pub max_concurrency: u32,
    /// Implementation language of the worker.
    pub language: WorkerLanguage,
}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub enum WorkerLanguage {
74 TypeScript,
75 Python,
76 Go,
77 Other(String),
78}
79
/// Load report carried by `Message::Backpressure`.
///
/// Fields serialize in camelCase (`workerId`, `currentLoad`, `shouldThrottle`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BackpressureSignal {
    /// Id of the reporting worker.
    pub worker_id: String,
    /// Current load figure. The tests use values like 0.85, so it is
    /// presumably a 0.0–1.0 ratio (TODO confirm).
    pub current_load: f64,
    /// Whether the sender asks for dispatch to be throttled.
    pub should_throttle: bool,
}
88
/// Common interface for delivering [`Message`]s to workers.
///
/// The `Send + Sync` supertraits let a transport be shared across threads.
/// `#[async_trait]` rewrites the `async fn`s into methods that return boxed
/// futures, which lets the trait be used with async methods.
///
/// NOTE(review): the `docker`, `ssh`, `stdio`, `wasm` and `websocket`
/// submodules presumably provide implementations — confirm.
#[async_trait]
pub trait Transport: Send + Sync {
    /// Starts the transport.
    ///
    /// # Errors
    /// Returns a [`TransportError`] if the transport cannot be started.
    async fn start(&self) -> Result<(), TransportError>;
    /// Stops the transport.
    ///
    /// # Errors
    /// Returns a [`TransportError`] if the transport cannot be stopped cleanly.
    async fn stop(&self) -> Result<(), TransportError>;
    /// Sends `message` to the worker identified by `worker_id`.
    ///
    /// # Errors
    /// Returns a [`TransportError`] on failure. Which variant is used (e.g.
    /// `WorkerNotFound` vs `SendFailed`) is implementation-defined.
    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError>;
    /// Sends `message` to every worker reachable through this transport.
    /// The exact recipient set is implementation-defined.
    ///
    /// # Errors
    /// Returns a [`TransportError`] on failure.
    async fn broadcast(&self, message: Message) -> Result<(), TransportError>;
}
97
/// Errors returned by [`Transport`] operations.
///
/// `Display` text comes from the `#[error(...)]` attributes via `thiserror`.
#[derive(Debug, thiserror::Error)]
pub enum TransportError {
    /// Establishing a connection failed. Displays as `Connection failed: <detail>`.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),
    /// The named worker is unknown. Displays as `Worker not found: <id>`.
    #[error("Worker not found: {0}")]
    WorkerNotFound(String),
    /// Delivering a message failed. Displays as `Send failed: <detail>`.
    #[error("Send failed: {0}")]
    SendFailed(String),
    /// The transport is closed. Displays as `Transport closed`.
    #[error("Transport closed")]
    Closed,
}
110
#[cfg(test)]
mod tests {
    // Serde round-trip and wire-format tests for the transport types.
    use super::*;
    use crate::schema::*;
    use serde_json::json;
    use uuid::Uuid;

    #[test]
    fn test_transport_config_defaults() {
        let config = TransportConfig::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 9876);
        assert_eq!(config.max_connections, 100);
        assert_eq!(config.heartbeat_interval_ms, 5_000);
        assert_eq!(config.reconnect_delay_ms, 1_000);
        assert_eq!(config.max_reconnect_attempts, 10);
    }

    #[test]
    fn test_message_serde_task_dispatch() {
        let task = Task::new("scan", json!({"url": "http://x.com"}));
        let msg = Message::TaskDispatch { task };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"TaskDispatch""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::TaskDispatch { task: t } = de {
            assert_eq!(t.task_type, "scan");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_task_result() {
        let result = TaskResult {
            task_id: Uuid::new_v4(),
            status: TaskStatus::Completed,
            payload: Some(json!({"v": 1})),
            error: None,
            duration_ms: 100,
            worker_id: "w1".into(),
        };
        let msg = Message::TaskResult { result };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"TaskResult""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::TaskResult { .. }));
    }

    #[test]
    fn test_message_serde_heartbeat() {
        let msg = Message::Heartbeat {
            payload: HeartbeatPayload {
                worker_id: "w1".into(),
                active_tasks: 3,
                capacity: 10,
                uptime_seconds: 120,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("activeTasks"));
        assert!(json.contains("uptimeSeconds"));
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Heartbeat { .. }));
    }

    #[test]
    fn test_message_serde_heartbeat_ack() {
        let msg = Message::HeartbeatAck;
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"type":"HeartbeatAck"}"#);
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::HeartbeatAck));
    }

    #[test]
    fn test_message_serde_worker_register() {
        let msg = Message::WorkerRegister {
            registration: WorkerRegistration {
                worker_id: "w1".into(),
                supported_tasks: vec!["a".into(), "b".into()],
                max_concurrency: 5,
                language: WorkerLanguage::Python,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("supportedTasks"));
        assert!(json.contains("maxConcurrency"));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::WorkerRegister { registration: reg } = de {
            assert_eq!(reg.worker_id, "w1");
            assert_eq!(reg.supported_tasks.len(), 2);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_worker_registered() {
        let msg = Message::WorkerRegistered {
            worker_id: "w1".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::WorkerRegistered { worker_id } = de {
            assert_eq!(worker_id, "w1");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_backpressure() {
        let msg = Message::Backpressure {
            signal: BackpressureSignal {
                worker_id: "w1".into(),
                current_load: 0.85,
                should_throttle: true,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("currentLoad"));
        assert!(json.contains("shouldThrottle"));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::Backpressure { signal } = de {
            assert!((signal.current_load - 0.85).abs() < f64::EPSILON);
            assert!(signal.should_throttle);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_kill() {
        let id = Uuid::new_v4();
        let msg = Message::Kill {
            task_id: id,
            reason: "timeout".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `Message` has no `rename_all`, so variant fields stay snake_case on the wire.
        assert!(json.contains(r#""task_id""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::Kill { task_id, reason } = de {
            assert_eq!(task_id, id);
            assert_eq!(reason, "timeout");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_shutdown_graceful() {
        let msg = Message::Shutdown { graceful: true };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Shutdown { graceful: true }));
    }

    #[test]
    fn test_message_serde_shutdown_not_graceful() {
        let msg = Message::Shutdown { graceful: false };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Shutdown { graceful: false }));
    }

    #[test]
    fn test_message_internally_tagged() {
        let msg = Message::HeartbeatAck;
        let json = serde_json::to_string(&msg).unwrap();
        let val: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(val.get("type").is_some());
    }

    #[test]
    fn test_message_deserialization_rejects_unknown_type() {
        let bad = r#"{"type":"UnknownMessage","data":{}}"#;
        let result = serde_json::from_str::<Message>(bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_worker_language_serde_all_variants() {
        for lang in [
            WorkerLanguage::TypeScript,
            WorkerLanguage::Python,
            WorkerLanguage::Go,
        ] {
            let json = serde_json::to_string(&lang).unwrap();
            let de: WorkerLanguage = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{:?}", de), format!("{:?}", lang));
        }
    }

    #[test]
    fn test_worker_language_serde_other() {
        let lang = WorkerLanguage::Other("ruby".into());
        let json = serde_json::to_string(&lang).unwrap();
        let de: WorkerLanguage = serde_json::from_str(&json).unwrap();
        if let WorkerLanguage::Other(s) = de {
            assert_eq!(s, "ruby");
        } else {
            panic!("Expected Other variant");
        }
    }

    #[test]
    fn test_transport_error_display() {
        assert_eq!(
            TransportError::ConnectionFailed("timeout".into()).to_string(),
            "Connection failed: timeout"
        );
        assert_eq!(
            TransportError::WorkerNotFound("w1".into()).to_string(),
            "Worker not found: w1"
        );
        assert_eq!(
            TransportError::SendFailed("broken".into()).to_string(),
            "Send failed: broken"
        );
        assert_eq!(TransportError::Closed.to_string(), "Transport closed");
    }

    #[test]
    fn test_heartbeat_payload_camel_case() {
        let p = HeartbeatPayload {
            worker_id: "w".into(),
            active_tasks: 1,
            capacity: 5,
            uptime_seconds: 60,
        };
        let json = serde_json::to_string(&p).unwrap();
        assert!(json.contains("workerId"));
        assert!(json.contains("activeTasks"));
        assert!(json.contains("uptimeSeconds"));
        assert!(!json.contains("worker_id"));
    }

    #[test]
    fn test_worker_registration_serde_roundtrip() {
        let reg = WorkerRegistration {
            worker_id: "w1".into(),
            supported_tasks: vec!["a".into(), "b".into()],
            max_concurrency: 10,
            language: WorkerLanguage::Go,
        };
        let json = serde_json::to_string(&reg).unwrap();
        let de: WorkerRegistration = serde_json::from_str(&json).unwrap();
        assert_eq!(de.worker_id, "w1");
        assert_eq!(de.supported_tasks.len(), 2);
        assert_eq!(de.max_concurrency, 10);
    }

    #[test]
    fn test_worker_registration_empty_tasks() {
        let reg = WorkerRegistration {
            worker_id: "w".into(),
            supported_tasks: vec![],
            max_concurrency: 1,
            language: WorkerLanguage::TypeScript,
        };
        let json = serde_json::to_string(&reg).unwrap();
        let de: WorkerRegistration = serde_json::from_str(&json).unwrap();
        assert!(de.supported_tasks.is_empty());
    }

    #[test]
    fn test_backpressure_signal_serde() {
        let sig = BackpressureSignal {
            worker_id: "w1".into(),
            current_load: 0.95,
            should_throttle: true,
        };
        let json = serde_json::to_string(&sig).unwrap();
        assert!(json.contains("currentLoad"));
        assert!(json.contains("shouldThrottle"));
        let de: BackpressureSignal = serde_json::from_str(&json).unwrap();
        assert!((de.current_load - 0.95).abs() < f64::EPSILON);
    }
}