1pub mod docker;
2pub mod ssh;
3pub mod stdio;
4pub mod wasm;
5pub mod websocket;
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use crate::schema::{Task, TaskResult};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TransportConfig {
15 pub host: String,
16 pub port: u16,
17 pub max_connections: u32,
18 pub heartbeat_interval_ms: u64,
19 pub reconnect_delay_ms: u64,
20 pub max_reconnect_attempts: u32,
21}
22
23impl Default for TransportConfig {
24 fn default() -> Self {
25 Self {
26 host: "0.0.0.0".to_string(),
27 port: 9876,
28 max_connections: 100,
29 heartbeat_interval_ms: 5_000,
30 reconnect_delay_ms: 1_000,
31 max_reconnect_attempts: 10,
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38#[serde(tag = "type")]
39pub enum Message {
40 TaskDispatch { task: Task },
41 TaskResult { result: TaskResult },
42 Heartbeat { payload: HeartbeatPayload },
43 HeartbeatAck,
44 WorkerRegister { registration: WorkerRegistration },
45 WorkerRegistered { worker_id: String },
46 Backpressure { signal: BackpressureSignal },
47 Kill { task_id: uuid::Uuid, reason: String },
48 Shutdown { graceful: bool },
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(rename_all = "camelCase")]
54pub struct HeartbeatPayload {
55 pub worker_id: String,
56 pub active_tasks: u32,
57 pub capacity: u32,
58 pub uptime_seconds: u64,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(rename_all = "camelCase")]
64pub struct WorkerRegistration {
65 pub worker_id: String,
66 pub supported_tasks: Vec<String>,
67 pub max_concurrency: u32,
68 pub language: WorkerLanguage,
69 #[serde(default)]
71 pub tags: Option<Vec<String>>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub enum WorkerLanguage {
77 TypeScript,
78 Python,
79 Go,
80 Other(String),
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85#[serde(rename_all = "camelCase")]
86pub struct BackpressureSignal {
87 pub worker_id: String,
88 pub current_load: f64,
89 pub should_throttle: bool,
90}
91
92#[async_trait]
94pub trait Transport: Send + Sync {
95 async fn start(&self) -> Result<(), TransportError>;
96 async fn stop(&self) -> Result<(), TransportError>;
97 async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError>;
98 async fn broadcast(&self, message: Message) -> Result<(), TransportError>;
99}
100
101#[derive(Debug, thiserror::Error)]
103pub enum TransportError {
104 #[error("Connection failed: {0}")]
105 ConnectionFailed(String),
106 #[error("Worker not found: {0}")]
107 WorkerNotFound(String),
108 #[error("Send failed: {0}")]
109 SendFailed(String),
110 #[error("Transport closed")]
111 Closed,
112}
113
#[cfg(test)]
mod tests {
    //! Serialization round-trip tests for the transport wire types.

    use super::*;
    use crate::schema::*;
    use serde_json::json;
    use uuid::Uuid;

    #[test]
    fn test_transport_config_defaults() {
        let config = TransportConfig::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 9876);
        assert_eq!(config.max_connections, 100);
        assert_eq!(config.heartbeat_interval_ms, 5_000);
        assert_eq!(config.reconnect_delay_ms, 1_000);
        assert_eq!(config.max_reconnect_attempts, 10);
    }

    #[test]
    fn test_message_serde_task_dispatch() {
        let task = Task::new("scan", json!({"url": "http://x.com"}));
        let msg = Message::TaskDispatch { task };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"TaskDispatch""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::TaskDispatch { task: t } = de {
            assert_eq!(t.task_type, "scan");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_task_result() {
        let result = TaskResult {
            task_id: Uuid::new_v4(),
            status: TaskStatus::Completed,
            payload: Some(json!({"v": 1})),
            error: None,
            duration_ms: 100,
            worker_id: "w1".into(),
        };
        let msg = Message::TaskResult { result };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"TaskResult""#));
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::TaskResult { .. }));
    }

    #[test]
    fn test_message_serde_heartbeat() {
        let msg = Message::Heartbeat {
            payload: HeartbeatPayload {
                worker_id: "w1".into(),
                active_tasks: 3,
                capacity: 10,
                uptime_seconds: 120,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("activeTasks"));
        assert!(json.contains("uptimeSeconds"));
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Heartbeat { .. }));
    }

    #[test]
    fn test_message_serde_heartbeat_ack() {
        let msg = Message::HeartbeatAck;
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"type":"HeartbeatAck"}"#);
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::HeartbeatAck));
    }

    #[test]
    fn test_message_serde_worker_register() {
        let msg = Message::WorkerRegister {
            registration: WorkerRegistration {
                worker_id: "w1".into(),
                supported_tasks: vec!["a".into(), "b".into()],
                max_concurrency: 5,
                language: WorkerLanguage::Python,
                tags: None,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("supportedTasks"));
        assert!(json.contains("maxConcurrency"));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::WorkerRegister { registration: reg } = de {
            assert_eq!(reg.worker_id, "w1");
            assert_eq!(reg.supported_tasks.len(), 2);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_worker_registered() {
        let msg = Message::WorkerRegistered {
            worker_id: "w1".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::WorkerRegistered { worker_id } = de {
            assert_eq!(worker_id, "w1");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_backpressure() {
        let msg = Message::Backpressure {
            signal: BackpressureSignal {
                worker_id: "w1".into(),
                current_load: 0.85,
                should_throttle: true,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("currentLoad"));
        assert!(json.contains("shouldThrottle"));
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::Backpressure { signal } = de {
            assert!((signal.current_load - 0.85).abs() < f64::EPSILON);
            assert!(signal.should_throttle);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_kill() {
        let id = Uuid::new_v4();
        let msg = Message::Kill {
            task_id: id,
            reason: "timeout".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        if let Message::Kill { task_id, reason } = de {
            assert_eq!(task_id, id);
            assert_eq!(reason, "timeout");
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_message_serde_shutdown_graceful() {
        let msg = Message::Shutdown { graceful: true };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Shutdown { graceful: true }));
    }

    #[test]
    fn test_message_serde_shutdown_not_graceful() {
        let msg = Message::Shutdown { graceful: false };
        let json = serde_json::to_string(&msg).unwrap();
        let de: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, Message::Shutdown { graceful: false }));
    }

    #[test]
    fn test_message_internally_tagged() {
        let msg = Message::HeartbeatAck;
        let json = serde_json::to_string(&msg).unwrap();
        let val: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(val.get("type").is_some());
    }

    #[test]
    fn test_message_deserialization_rejects_unknown_type() {
        let bad = r#"{"type":"UnknownMessage","data":{}}"#;
        let result = serde_json::from_str::<Message>(bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_worker_language_serde_all_variants() {
        for lang in [
            WorkerLanguage::TypeScript,
            WorkerLanguage::Python,
            WorkerLanguage::Go,
        ] {
            let json = serde_json::to_string(&lang).unwrap();
            let de: WorkerLanguage = serde_json::from_str(&json).unwrap();
            // WorkerLanguage does not derive PartialEq, so compare Debug output.
            assert_eq!(format!("{:?}", de), format!("{:?}", lang));
        }
    }

    #[test]
    fn test_worker_language_serde_other() {
        let lang = WorkerLanguage::Other("ruby".into());
        let json = serde_json::to_string(&lang).unwrap();
        let de: WorkerLanguage = serde_json::from_str(&json).unwrap();
        if let WorkerLanguage::Other(s) = de {
            assert_eq!(s, "ruby");
        } else {
            panic!("Expected Other variant");
        }
    }

    #[test]
    fn test_transport_error_display() {
        assert_eq!(
            TransportError::ConnectionFailed("timeout".into()).to_string(),
            "Connection failed: timeout"
        );
        assert_eq!(
            TransportError::WorkerNotFound("w1".into()).to_string(),
            "Worker not found: w1"
        );
        assert_eq!(
            TransportError::SendFailed("broken".into()).to_string(),
            "Send failed: broken"
        );
        assert_eq!(TransportError::Closed.to_string(), "Transport closed");
    }

    #[test]
    fn test_heartbeat_payload_camel_case() {
        let p = HeartbeatPayload {
            worker_id: "w".into(),
            active_tasks: 1,
            capacity: 5,
            uptime_seconds: 60,
        };
        let json = serde_json::to_string(&p).unwrap();
        assert!(json.contains("workerId"));
        assert!(json.contains("activeTasks"));
        assert!(json.contains("uptimeSeconds"));
        assert!(!json.contains("worker_id"));
    }

    #[test]
    fn test_worker_registration_serde_roundtrip() {
        let reg = WorkerRegistration {
            worker_id: "w1".into(),
            supported_tasks: vec!["a".into(), "b".into()],
            max_concurrency: 10,
            language: WorkerLanguage::Go,
            tags: Some(vec!["gpu".into()]),
        };
        let json = serde_json::to_string(&reg).unwrap();
        let de: WorkerRegistration = serde_json::from_str(&json).unwrap();
        assert_eq!(de.worker_id, "w1");
        assert_eq!(de.supported_tasks.len(), 2);
        assert_eq!(de.max_concurrency, 10);
        assert_eq!(de.tags, Some(vec!["gpu".to_string()]));
    }

    #[test]
    fn test_worker_registration_empty_tasks() {
        let reg = WorkerRegistration {
            worker_id: "w".into(),
            supported_tasks: vec![],
            max_concurrency: 1,
            language: WorkerLanguage::TypeScript,
            tags: None,
        };
        let json = serde_json::to_string(&reg).unwrap();
        let de: WorkerRegistration = serde_json::from_str(&json).unwrap();
        assert!(de.supported_tasks.is_empty());
    }

    #[test]
    fn test_worker_registration_missing_tags_defaults_to_none() {
        // Build the document from a real serialization, then drop the `tags` key.
        // That way this test depends only on `tags` being optional on input,
        // not on how the other fields are encoded.
        let reg = WorkerRegistration {
            worker_id: "w1".into(),
            supported_tasks: vec!["a".into()],
            max_concurrency: 2,
            language: WorkerLanguage::Go,
            tags: Some(vec!["gpu".into()]),
        };
        let mut val = serde_json::to_value(&reg).unwrap();
        val.as_object_mut()
            .expect("WorkerRegistration serializes as a JSON object")
            .remove("tags");
        let de: WorkerRegistration = serde_json::from_value(val).unwrap();
        assert!(de.tags.is_none());
        assert_eq!(de.worker_id, "w1");
    }

    #[test]
    fn test_backpressure_signal_serde() {
        let sig = BackpressureSignal {
            worker_id: "w1".into(),
            current_load: 0.95,
            should_throttle: true,
        };
        let json = serde_json::to_string(&sig).unwrap();
        assert!(json.contains("currentLoad"));
        assert!(json.contains("shouldThrottle"));
        let de: BackpressureSignal = serde_json::from_str(&json).unwrap();
        assert!((de.current_load - 0.95).abs() < f64::EPSILON);
    }
}