Skip to main content

rust_pipe/dispatch/
mod.rs

1use dashmap::DashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::{oneshot, RwLock};
6use tokio::task::JoinHandle;
7use uuid::Uuid;
8
9use crate::schema::{Task, TaskResult};
10use crate::transport::websocket::WebSocketTransport;
11use crate::transport::{Message, Transport, TransportConfig, TransportError};
12use crate::worker::{WorkerInfo, WorkerPool, WorkerStatus};
13
14/// Builder for configuring a [`Dispatcher`].
15#[derive(Debug, Clone)]
16pub struct DispatcherBuilder {
17    config: TransportConfig,
18    heartbeat_timeout_ms: u64,
19    dead_worker_check_interval_ms: u64,
20}
21
22impl Default for DispatcherBuilder {
23    fn default() -> Self {
24        Self {
25            config: TransportConfig::default(),
26            heartbeat_timeout_ms: 15_000,
27            dead_worker_check_interval_ms: 5_000,
28        }
29    }
30}
31
32impl DispatcherBuilder {
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    pub fn host(mut self, host: impl Into<String>) -> Self {
38        self.config.host = host.into();
39        self
40    }
41
42    pub fn port(mut self, port: u16) -> Self {
43        self.config.port = port;
44        self
45    }
46
47    pub fn max_connections(mut self, max: u32) -> Self {
48        self.config.max_connections = max;
49        self
50    }
51
52    pub fn heartbeat_interval(mut self, ms: u64) -> Self {
53        self.config.heartbeat_interval_ms = ms;
54        self
55    }
56
57    pub fn heartbeat_timeout(mut self, ms: u64) -> Self {
58        self.heartbeat_timeout_ms = ms;
59        self
60    }
61
62    pub fn build(self) -> Dispatcher {
63        Dispatcher {
64            pool: Arc::new(WorkerPool::new(self.heartbeat_timeout_ms)),
65            pending: Arc::new(DashMap::new()),
66            transport: Arc::new(RwLock::new(None)),
67            config: self.config,
68            dead_worker_check_interval_ms: self.dead_worker_check_interval_ms,
69            started: AtomicBool::new(false),
70            _dead_worker_task: RwLock::new(None),
71        }
72    }
73}
74
75/// Central task dispatcher. Manages worker connections and routes tasks.
76pub struct Dispatcher {
77    pool: Arc<WorkerPool>,
78    pending: Arc<DashMap<Uuid, PendingTask>>,
79    transport: Arc<RwLock<Option<Arc<WebSocketTransport>>>>,
80    config: TransportConfig,
81    dead_worker_check_interval_ms: u64,
82    started: AtomicBool,
83    _dead_worker_task: RwLock<Option<JoinHandle<()>>>,
84}
85
86struct PendingTask {
87    sender: oneshot::Sender<TaskResult>,
88    worker_id: String,
89}
90
91/// Handle to an in-flight task. Await it to get the worker's result.
92#[must_use = "dropping a DispatchResult discards the task result"]
93pub struct DispatchResult {
94    pub task_id: Uuid,
95    pub(crate) receiver: oneshot::Receiver<TaskResult>,
96}
97
98impl std::fmt::Debug for DispatchResult {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct("DispatchResult")
101            .field("task_id", &self.task_id)
102            .finish()
103    }
104}
105
106impl DispatchResult {
107    pub async fn await_result(self) -> Result<TaskResult, DispatchError> {
108        self.receiver
109            .await
110            .map_err(|_| DispatchError::WorkerDisconnected)
111    }
112
113    pub async fn await_with_timeout(self, timeout: Duration) -> Result<TaskResult, DispatchError> {
114        tokio::time::timeout(timeout, self.receiver)
115            .await
116            .map_err(|_| DispatchError::Timeout)?
117            .map_err(|_| DispatchError::WorkerDisconnected)
118    }
119}
120
121impl Dispatcher {
122    pub fn builder() -> DispatcherBuilder {
123        DispatcherBuilder::new()
124    }
125
126    pub async fn start(&self) -> Result<(), DispatchError> {
127        if self.started.swap(true, Ordering::SeqCst) {
128            return Ok(());
129        }
130
131        let pool = self.pool.clone();
132        let pending = self.pending.clone();
133
134        let on_message = move |worker_id: String, message: Message| {
135            let pool = pool.clone();
136            let pending = pending.clone();
137
138            tokio::spawn(async move {
139                match message {
140                    Message::WorkerRegister { registration: reg } => {
141                        pool.register(WorkerInfo {
142                            id: reg.worker_id,
143                            language: reg.language,
144                            supported_tasks: reg.supported_tasks,
145                            max_concurrency: reg.max_concurrency,
146                            status: WorkerStatus::Active,
147                            active_tasks: 0,
148                            registered_at: chrono::Utc::now(),
149                            last_heartbeat: chrono::Utc::now(),
150                        });
151                    }
152                    Message::TaskResult { result } => {
153                        pool.mark_task_completed(&worker_id);
154                        if let Some((_, pending_task)) = pending.remove(&result.task_id) {
155                            let _ = pending_task.sender.send(result);
156                        }
157                    }
158                    Message::Heartbeat { payload: hb } => {
159                        pool.heartbeat(&hb.worker_id, hb.active_tasks);
160                    }
161                    Message::Backpressure { signal: bp } => {
162                        tracing::warn!(
163                            worker_id = %bp.worker_id,
164                            load = bp.current_load,
165                            "Worker signaled backpressure"
166                        );
167                    }
168                    _ => {}
169                }
170            });
171        };
172
173        let transport = Arc::new(WebSocketTransport::new(self.config.clone(), on_message));
174        transport
175            .start()
176            .await
177            .map_err(DispatchError::TransportError)?;
178
179        *self.transport.write().await = Some(transport);
180
181        // Dead worker detection loop — drops pending senders for dead workers
182        // so that await_result() returns WorkerDisconnected instead of hanging.
183        let pool = self.pool.clone();
184        let pending = self.pending.clone();
185        let interval = self.dead_worker_check_interval_ms;
186        let handle = tokio::spawn(async move {
187            loop {
188                tokio::time::sleep(Duration::from_millis(interval)).await;
189                let dead = pool.detect_dead_workers();
190                if !dead.is_empty() {
191                    for worker_id in &dead {
192                        tracing::warn!(worker_id = %worker_id, "Dead worker detected");
193                    }
194                    // Drop pending senders for dead workers — receiver gets RecvError
195                    // which maps to DispatchError::WorkerDisconnected
196                    pending
197                        .retain(|_task_id, pending_task| !dead.contains(&pending_task.worker_id));
198                }
199            }
200        });
201
202        *self._dead_worker_task.write().await = Some(handle);
203
204        Ok(())
205    }
206
207    /// Gracefully stops the dispatcher. Cancels background tasks and shuts down transport.
208    pub async fn stop(&self) {
209        self.started.store(false, Ordering::SeqCst);
210        // Cancel the dead worker detection task
211        if let Some(handle) = self._dead_worker_task.write().await.take() {
212            handle.abort();
213        }
214        // Shut down transport
215        if let Some(transport) = self.transport.read().await.as_ref() {
216            let _ = transport.stop().await;
217        }
218        // Fail all pending tasks
219        self.pending.clear();
220    }
221
222    pub async fn dispatch(&self, task: Task) -> Result<DispatchResult, DispatchError> {
223        // select_and_reserve atomically picks a worker and increments active_tasks
224        let worker_id = self.pool.select_and_reserve(&task.task_type).ok_or(
225            DispatchError::NoWorkerAvailable {
226                task_type: task.task_type.clone(),
227            },
228        )?;
229
230        let (tx, rx) = oneshot::channel();
231        let task_id = task.id;
232
233        self.pending.insert(
234            task_id,
235            PendingTask {
236                sender: tx,
237                worker_id: worker_id.clone(),
238            },
239        );
240
241        // Send task to worker via transport
242        let transport_guard = self.transport.read().await;
243        let transport = transport_guard.as_ref().ok_or_else(|| {
244            // Rollback: remove pending and release worker capacity
245            self.pending.remove(&task_id);
246            self.pool.mark_task_completed(&worker_id);
247            DispatchError::TransportNotStarted
248        })?;
249
250        if let Err(e) = transport
251            .send(&worker_id, Message::TaskDispatch { task })
252            .await
253        {
254            // Rollback: remove pending and release worker capacity
255            self.pending.remove(&task_id);
256            self.pool.mark_task_completed(&worker_id);
257            return Err(DispatchError::TransportError(e));
258        }
259
260        tracing::debug!(task_id = %task_id, worker_id = %worker_id, "Task dispatched");
261
262        Ok(DispatchResult {
263            task_id,
264            receiver: rx,
265        })
266    }
267
268    pub fn pool_stats(&self) -> crate::worker::PoolStats {
269        self.pool.stats()
270    }
271}
272
273#[derive(Debug, thiserror::Error)]
274pub enum DispatchError {
275    #[error("No worker available for task type: {task_type}")]
276    NoWorkerAvailable { task_type: String },
277
278    #[error("Worker disconnected before returning result")]
279    WorkerDisconnected,
280
281    #[error("Task timed out")]
282    Timeout,
283
284    #[error("Transport not started — call start() first")]
285    TransportNotStarted,
286
287    #[error("Transport error: {0}")]
288    TransportError(#[from] TransportError),
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{TaskResult, TaskStatus};
    use serde_json::json;

    /// Builds a handle with a fresh task id together with the sender that
    /// feeds it.
    fn handle_with_sender() -> (oneshot::Sender<TaskResult>, DispatchResult) {
        let (tx, rx) = oneshot::channel();
        let handle = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        (tx, handle)
    }

    // --- Builder defaults ---

    #[test]
    fn test_builder_default_port() {
        assert_eq!(DispatcherBuilder::new().config.port, 9876);
    }

    #[test]
    fn test_builder_default_host() {
        assert_eq!(DispatcherBuilder::new().config.host, "0.0.0.0");
    }

    #[test]
    fn test_builder_default_heartbeat_timeout() {
        assert_eq!(DispatcherBuilder::new().heartbeat_timeout_ms, 15_000);
    }

    // --- Builder setters ---

    #[test]
    fn test_builder_host_sets_value() {
        let configured = DispatcherBuilder::new().host("10.0.0.1");
        assert_eq!(configured.config.host, "10.0.0.1");
    }

    #[test]
    fn test_builder_port_sets_value() {
        let configured = DispatcherBuilder::new().port(8080);
        assert_eq!(configured.config.port, 8080);
    }

    #[test]
    fn test_builder_max_connections_sets_value() {
        let configured = DispatcherBuilder::new().max_connections(50);
        assert_eq!(configured.config.max_connections, 50);
    }

    #[test]
    fn test_builder_heartbeat_interval_sets_value() {
        let configured = DispatcherBuilder::new().heartbeat_interval(2000);
        assert_eq!(configured.config.heartbeat_interval_ms, 2000);
    }

    #[test]
    fn test_builder_heartbeat_timeout_sets_value() {
        let configured = DispatcherBuilder::new().heartbeat_timeout(30000);
        assert_eq!(configured.heartbeat_timeout_ms, 30000);
    }

    #[test]
    fn test_builder_chaining() {
        let configured = DispatcherBuilder::new()
            .host("1.2.3.4")
            .port(9999)
            .max_connections(200)
            .heartbeat_interval(1000)
            .heartbeat_timeout(5000);

        let transport_config = &configured.config;
        assert_eq!(transport_config.host, "1.2.3.4");
        assert_eq!(transport_config.port, 9999);
        assert_eq!(transport_config.max_connections, 200);
        assert_eq!(transport_config.heartbeat_interval_ms, 1000);
        assert_eq!(configured.heartbeat_timeout_ms, 5000);
    }

    #[test]
    fn test_builder_build_pool_starts_empty() {
        let stats = Dispatcher::builder().build().pool_stats();
        assert_eq!(stats.total, 0);
    }

    #[test]
    fn test_dispatcher_builder_shortcut() {
        assert_eq!(Dispatcher::builder().config.port, 9876);
    }

    // --- DispatchResult ---

    #[tokio::test]
    async fn test_dispatch_result_await_result_receives_value() {
        let (tx, handle) = handle_with_sender();
        let expected_id = handle.task_id;

        let outcome = TaskResult {
            task_id: expected_id,
            status: TaskStatus::Completed,
            payload: Some(json!({"ok": true})),
            error: None,
            duration_ms: 50,
            worker_id: "test".to_string(),
        };
        tx.send(outcome).unwrap();

        let received = handle.await_result().await.unwrap();
        assert_eq!(received.task_id, expected_id);
        assert_eq!(received.status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_dispatch_result_worker_disconnected() {
        let (tx, handle) = handle_with_sender();
        // Dropping the sender without sending simulates a lost worker.
        drop(tx);

        let failure = handle.await_result().await.unwrap_err();
        assert!(matches!(failure, DispatchError::WorkerDisconnected));
    }

    #[tokio::test]
    async fn test_dispatch_result_timeout() {
        // The sender is kept alive so the wait can only end by timing out.
        let (_tx, handle) = handle_with_sender();

        let failure = handle
            .await_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(failure, DispatchError::Timeout));
    }

    #[test]
    fn test_dispatch_result_debug_format() {
        let (_tx, handle) = handle_with_sender();
        let id_text = handle.task_id.to_string();

        let rendered = format!("{:?}", handle);
        assert!(rendered.contains("DispatchResult"));
        assert!(rendered.contains(&id_text));
    }

    // --- DispatchError ---

    #[test]
    fn test_dispatch_error_display_no_worker() {
        let message = DispatchError::NoWorkerAvailable {
            task_type: "scan".into(),
        }
        .to_string();
        assert_eq!(message, "No worker available for task type: scan");
    }

    #[test]
    fn test_dispatch_error_display_worker_disconnected() {
        let message = DispatchError::WorkerDisconnected.to_string();
        assert_eq!(message, "Worker disconnected before returning result");
    }

    #[test]
    fn test_dispatch_error_display_timeout() {
        let message = DispatchError::Timeout.to_string();
        assert_eq!(message, "Task timed out");
    }

    #[test]
    fn test_dispatch_error_display_transport_not_started() {
        let message = DispatchError::TransportNotStarted.to_string();
        assert!(message.contains("Transport not started"));
    }

    #[test]
    fn test_dispatch_error_from_transport_error() {
        let converted = DispatchError::from(TransportError::Closed);
        assert!(matches!(
            converted,
            DispatchError::TransportError(TransportError::Closed)
        ));
    }
}