// rust_pipe/dispatch/mod.rs

1use dashmap::DashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::{oneshot, RwLock};
6use tokio::task::JoinHandle;
7use uuid::Uuid;
8
9use crate::schema::{Task, TaskResult};
10use crate::transport::websocket::WebSocketTransport;
11use crate::transport::{Message, Transport, TransportConfig, TransportError};
12use crate::worker::{PoolError, WorkerInfo, WorkerPool, WorkerStatus};
13
14/// Builder for configuring a [`Dispatcher`].
15pub struct DispatcherBuilder {
16    config: TransportConfig,
17    heartbeat_timeout_ms: u64,
18    dead_worker_check_interval_ms: u64,
19    max_pool_size: Option<u32>,
20    min_pool_size: Option<u32>,
21    on_pool_below_min: Option<Arc<dyn Fn(u32) + Send + Sync>>,
22}
23
24impl std::fmt::Debug for DispatcherBuilder {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("DispatcherBuilder")
27            .field("config", &self.config)
28            .field("heartbeat_timeout_ms", &self.heartbeat_timeout_ms)
29            .field(
30                "dead_worker_check_interval_ms",
31                &self.dead_worker_check_interval_ms,
32            )
33            .field("max_pool_size", &self.max_pool_size)
34            .field("min_pool_size", &self.min_pool_size)
35            .field("on_pool_below_min", &self.on_pool_below_min.is_some())
36            .finish()
37    }
38}
39
40impl Default for DispatcherBuilder {
41    fn default() -> Self {
42        Self {
43            config: TransportConfig::default(),
44            heartbeat_timeout_ms: 15_000,
45            dead_worker_check_interval_ms: 5_000,
46            max_pool_size: None,
47            min_pool_size: None,
48            on_pool_below_min: None,
49        }
50    }
51}
52
53impl DispatcherBuilder {
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    pub fn host(mut self, host: impl Into<String>) -> Self {
59        self.config.host = host.into();
60        self
61    }
62
63    pub fn port(mut self, port: u16) -> Self {
64        self.config.port = port;
65        self
66    }
67
68    pub fn max_connections(mut self, max: u32) -> Self {
69        self.config.max_connections = max;
70        self
71    }
72
73    pub fn heartbeat_interval(mut self, ms: u64) -> Self {
74        self.config.heartbeat_interval_ms = ms;
75        self
76    }
77
78    pub fn heartbeat_timeout(mut self, ms: u64) -> Self {
79        self.heartbeat_timeout_ms = ms;
80        self
81    }
82
83    /// Set the maximum number of workers allowed in the pool.
84    /// Workers connecting beyond this limit will be rejected.
85    pub fn max_pool_size(mut self, max: u32) -> Self {
86        self.max_pool_size = Some(max);
87        self
88    }
89
90    /// Set the minimum pool size. When the pool drops below this threshold,
91    /// the `on_pool_below_min` callback is invoked.
92    pub fn min_pool_size(mut self, min: u32) -> Self {
93        self.min_pool_size = Some(min);
94        self
95    }
96
97    /// Set a callback to be invoked when the pool size drops below `min_pool_size`.
98    /// The callback receives the current pool size.
99    pub fn on_pool_below_min(mut self, cb: impl Fn(u32) + Send + Sync + 'static) -> Self {
100        self.on_pool_below_min = Some(Arc::new(cb));
101        self
102    }
103
104    pub fn build(self) -> Dispatcher {
105        Dispatcher {
106            pool: Arc::new(WorkerPool::with_limits(
107                self.heartbeat_timeout_ms,
108                self.max_pool_size,
109                self.min_pool_size,
110                self.on_pool_below_min,
111            )),
112            pending: Arc::new(DashMap::new()),
113            transport: Arc::new(RwLock::new(None)),
114            config: self.config,
115            dead_worker_check_interval_ms: self.dead_worker_check_interval_ms,
116            started: AtomicBool::new(false),
117            _dead_worker_task: RwLock::new(None),
118        }
119    }
120}
121
122/// Central task dispatcher. Manages worker connections and routes tasks.
123pub struct Dispatcher {
124    pool: Arc<WorkerPool>,
125    pending: Arc<DashMap<Uuid, PendingTask>>,
126    transport: Arc<RwLock<Option<Arc<WebSocketTransport>>>>,
127    config: TransportConfig,
128    dead_worker_check_interval_ms: u64,
129    started: AtomicBool,
130    _dead_worker_task: RwLock<Option<JoinHandle<()>>>,
131}
132
133struct PendingTask {
134    sender: oneshot::Sender<TaskResult>,
135    worker_id: String,
136}
137
138/// Handle to an in-flight task. Await it to get the worker's result.
139#[must_use = "dropping a DispatchResult discards the task result"]
140pub struct DispatchResult {
141    pub task_id: Uuid,
142    pub(crate) receiver: oneshot::Receiver<TaskResult>,
143}
144
145impl std::fmt::Debug for DispatchResult {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        f.debug_struct("DispatchResult")
148            .field("task_id", &self.task_id)
149            .finish()
150    }
151}
152
153impl DispatchResult {
154    pub async fn await_result(self) -> Result<TaskResult, DispatchError> {
155        self.receiver
156            .await
157            .map_err(|_| DispatchError::WorkerDisconnected)
158    }
159
160    pub async fn await_with_timeout(self, timeout: Duration) -> Result<TaskResult, DispatchError> {
161        tokio::time::timeout(timeout, self.receiver)
162            .await
163            .map_err(|_| DispatchError::Timeout)?
164            .map_err(|_| DispatchError::WorkerDisconnected)
165    }
166}
167
168impl Dispatcher {
169    pub fn builder() -> DispatcherBuilder {
170        DispatcherBuilder::new()
171    }
172
173    pub async fn start(&self) -> Result<(), DispatchError> {
174        if self.started.swap(true, Ordering::SeqCst) {
175            return Ok(());
176        }
177
178        let pool = self.pool.clone();
179        let pending = self.pending.clone();
180
181        let on_message = move |worker_id: String, message: Message| {
182            let pool = pool.clone();
183            let pending = pending.clone();
184
185            tokio::spawn(async move {
186                match message {
187                    Message::WorkerRegister { registration: reg } => {
188                        pool.register(WorkerInfo {
189                            id: reg.worker_id,
190                            language: reg.language,
191                            supported_tasks: reg.supported_tasks,
192                            max_concurrency: reg.max_concurrency,
193                            status: WorkerStatus::Active,
194                            active_tasks: 0,
195                            registered_at: chrono::Utc::now(),
196                            last_heartbeat: chrono::Utc::now(),
197                            tags: reg.tags.unwrap_or_default(),
198                        });
199                    }
200                    Message::TaskResult { result } => {
201                        pool.mark_task_completed(&worker_id);
202                        if let Some((_, pending_task)) = pending.remove(&result.task_id) {
203                            let _ = pending_task.sender.send(result);
204                        }
205                    }
206                    Message::Heartbeat { payload: hb } => {
207                        pool.heartbeat(&hb.worker_id, hb.active_tasks);
208                    }
209                    Message::Backpressure { signal: bp } => {
210                        tracing::warn!(
211                            worker_id = %bp.worker_id,
212                            load = bp.current_load,
213                            "Worker signaled backpressure"
214                        );
215                    }
216                    _ => {}
217                }
218            });
219        };
220
221        let transport = Arc::new(WebSocketTransport::new(self.config.clone(), on_message));
222        transport
223            .start()
224            .await
225            .map_err(DispatchError::TransportError)?;
226
227        *self.transport.write().await = Some(transport);
228
229        // Dead worker detection loop — drops pending senders for dead workers
230        // so that await_result() returns WorkerDisconnected instead of hanging.
231        let pool = self.pool.clone();
232        let pending = self.pending.clone();
233        let interval = self.dead_worker_check_interval_ms;
234        let handle = tokio::spawn(async move {
235            loop {
236                tokio::time::sleep(Duration::from_millis(interval)).await;
237                let dead = pool.detect_dead_workers();
238                if !dead.is_empty() {
239                    for worker_id in &dead {
240                        tracing::warn!(worker_id = %worker_id, "Dead worker detected");
241                    }
242                    // Drop pending senders for dead workers — receiver gets RecvError
243                    // which maps to DispatchError::WorkerDisconnected
244                    pending
245                        .retain(|_task_id, pending_task| !dead.contains(&pending_task.worker_id));
246                }
247            }
248        });
249
250        *self._dead_worker_task.write().await = Some(handle);
251
252        Ok(())
253    }
254
255    /// Gracefully stops the dispatcher. Cancels background tasks and shuts down transport.
256    pub async fn stop(&self) {
257        self.started.store(false, Ordering::SeqCst);
258        // Cancel the dead worker detection task
259        if let Some(handle) = self._dead_worker_task.write().await.take() {
260            handle.abort();
261        }
262        // Shut down transport
263        if let Some(transport) = self.transport.read().await.as_ref() {
264            let _ = transport.stop().await;
265        }
266        // Fail all pending tasks
267        self.pending.clear();
268    }
269
270    pub async fn dispatch(&self, task: Task) -> Result<DispatchResult, DispatchError> {
271        // select_and_reserve atomically picks a worker and increments active_tasks
272        let worker_id = self.pool.select_and_reserve(&task.task_type).ok_or(
273            DispatchError::NoWorkerAvailable {
274                task_type: task.task_type.clone(),
275            },
276        )?;
277
278        let (tx, rx) = oneshot::channel();
279        let task_id = task.id;
280
281        self.pending.insert(
282            task_id,
283            PendingTask {
284                sender: tx,
285                worker_id: worker_id.clone(),
286            },
287        );
288
289        // Send task to worker via transport
290        let transport_guard = self.transport.read().await;
291        let transport = transport_guard.as_ref().ok_or_else(|| {
292            // Rollback: remove pending and release worker capacity
293            self.pending.remove(&task_id);
294            self.pool.mark_task_completed(&worker_id);
295            DispatchError::TransportNotStarted
296        })?;
297
298        if let Err(e) = transport
299            .send(&worker_id, Message::TaskDispatch { task })
300            .await
301        {
302            // Rollback: remove pending and release worker capacity
303            self.pending.remove(&task_id);
304            self.pool.mark_task_completed(&worker_id);
305            return Err(DispatchError::TransportError(e));
306        }
307
308        tracing::debug!(task_id = %task_id, worker_id = %worker_id, "Task dispatched");
309
310        Ok(DispatchResult {
311            task_id,
312            receiver: rx,
313        })
314    }
315
316    pub fn pool_stats(&self) -> crate::worker::PoolStats {
317        self.pool.stats()
318    }
319
320    /// List all connected workers with their full info.
321    pub fn workers(&self) -> Vec<WorkerInfo> {
322        self.pool.workers()
323    }
324
325    /// Set a worker's status to Draining. No new tasks will be routed to it,
326    /// but existing in-flight tasks will finish normally.
327    pub fn drain_worker(&self, worker_id: &str) -> Result<(), PoolError> {
328        self.pool.drain_worker(worker_id)
329    }
330
331    /// Force-remove a worker from the pool and fail all pending tasks assigned to it.
332    pub fn remove_worker(&self, worker_id: &str) -> Result<(), PoolError> {
333        self.pool.remove_worker(worker_id)?;
334        // Fail all pending tasks for this worker by dropping their senders
335        self.pending
336            .retain(|_task_id, pending_task| pending_task.worker_id != worker_id);
337        Ok(())
338    }
339
340    /// Dispatch a task to a specific worker by ID, bypassing least-loaded selection.
341    pub async fn dispatch_to(
342        &self,
343        worker_id: &str,
344        task: Task,
345    ) -> Result<DispatchResult, DispatchError> {
346        self.pool.reserve_specific_worker(worker_id)?;
347
348        let (tx, rx) = oneshot::channel();
349        let task_id = task.id;
350
351        self.pending.insert(
352            task_id,
353            PendingTask {
354                sender: tx,
355                worker_id: worker_id.to_string(),
356            },
357        );
358
359        // Send task to worker via transport
360        let transport_guard = self.transport.read().await;
361        let transport = transport_guard.as_ref().ok_or_else(|| {
362            self.pending.remove(&task_id);
363            self.pool.mark_task_completed(worker_id);
364            DispatchError::TransportNotStarted
365        })?;
366
367        if let Err(e) = transport
368            .send(worker_id, Message::TaskDispatch { task })
369            .await
370        {
371            self.pending.remove(&task_id);
372            self.pool.mark_task_completed(worker_id);
373            return Err(DispatchError::TransportError(e));
374        }
375
376        tracing::debug!(task_id = %task_id, worker_id = %worker_id, "Task dispatched to specific worker");
377
378        Ok(DispatchResult {
379            task_id,
380            receiver: rx,
381        })
382    }
383
384    /// Dispatch a task to a worker that has a matching tag.
385    /// Routes to the least-loaded worker among those with the specified tag.
386    pub async fn dispatch_with_tag(
387        &self,
388        tag: &str,
389        task: Task,
390    ) -> Result<DispatchResult, DispatchError> {
391        let worker_id = self
392            .pool
393            .select_and_reserve_with_tag(tag, &task.task_type)
394            .ok_or(DispatchError::NoWorkerAvailable {
395                task_type: task.task_type.clone(),
396            })?;
397
398        let (tx, rx) = oneshot::channel();
399        let task_id = task.id;
400
401        self.pending.insert(
402            task_id,
403            PendingTask {
404                sender: tx,
405                worker_id: worker_id.clone(),
406            },
407        );
408
409        // Send task to worker via transport
410        let transport_guard = self.transport.read().await;
411        let transport = transport_guard.as_ref().ok_or_else(|| {
412            self.pending.remove(&task_id);
413            self.pool.mark_task_completed(&worker_id);
414            DispatchError::TransportNotStarted
415        })?;
416
417        if let Err(e) = transport
418            .send(&worker_id, Message::TaskDispatch { task })
419            .await
420        {
421            self.pending.remove(&task_id);
422            self.pool.mark_task_completed(&worker_id);
423            return Err(DispatchError::TransportError(e));
424        }
425
426        tracing::debug!(task_id = %task_id, worker_id = %worker_id, tag = %tag, "Task dispatched with tag");
427
428        Ok(DispatchResult {
429            task_id,
430            receiver: rx,
431        })
432    }
433}
434
435#[derive(Debug, thiserror::Error)]
436pub enum DispatchError {
437    #[error("No worker available for task type: {task_type}")]
438    NoWorkerAvailable { task_type: String },
439
440    #[error("Worker disconnected before returning result")]
441    WorkerDisconnected,
442
443    #[error("Task timed out")]
444    Timeout,
445
446    #[error("Transport not started — call start() first")]
447    TransportNotStarted,
448
449    #[error("Transport error: {0}")]
450    TransportError(#[from] TransportError),
451
452    #[error("Pool error: {0}")]
453    PoolError(#[from] PoolError),
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::TaskStatus;
    use serde_json::json;

    /// Builds a completed result for `task_id` with a small payload.
    fn sample_result(task_id: Uuid) -> TaskResult {
        TaskResult {
            task_id,
            status: TaskStatus::Completed,
            payload: Some(json!({"ok": true})),
            error: None,
            duration_ms: 50,
            worker_id: "test".to_string(),
        }
    }

    #[test]
    fn test_builder_default_port() {
        let builder = DispatcherBuilder::new();
        assert_eq!(builder.config.port, 9876);
    }

    #[test]
    fn test_builder_default_host() {
        let builder = DispatcherBuilder::new();
        assert_eq!(builder.config.host, "0.0.0.0");
    }

    #[test]
    fn test_builder_default_heartbeat_timeout() {
        let builder = DispatcherBuilder::new();
        assert_eq!(builder.heartbeat_timeout_ms, 15_000);
    }

    #[test]
    fn test_builder_default_dead_worker_check_interval() {
        let builder = DispatcherBuilder::new();
        assert_eq!(builder.dead_worker_check_interval_ms, 5_000);
    }

    #[test]
    fn test_builder_host_sets_value() {
        let builder = DispatcherBuilder::new().host("10.0.0.1");
        assert_eq!(builder.config.host, "10.0.0.1");
    }

    #[test]
    fn test_builder_port_sets_value() {
        let builder = DispatcherBuilder::new().port(8080);
        assert_eq!(builder.config.port, 8080);
    }

    #[test]
    fn test_builder_max_connections_sets_value() {
        let builder = DispatcherBuilder::new().max_connections(50);
        assert_eq!(builder.config.max_connections, 50);
    }

    #[test]
    fn test_builder_heartbeat_interval_sets_value() {
        let builder = DispatcherBuilder::new().heartbeat_interval(2000);
        assert_eq!(builder.config.heartbeat_interval_ms, 2000);
    }

    #[test]
    fn test_builder_heartbeat_timeout_sets_value() {
        let builder = DispatcherBuilder::new().heartbeat_timeout(30000);
        assert_eq!(builder.heartbeat_timeout_ms, 30000);
    }

    #[test]
    fn test_builder_chaining() {
        let builder = DispatcherBuilder::new()
            .host("1.2.3.4")
            .port(9999)
            .max_connections(200)
            .heartbeat_interval(1000)
            .heartbeat_timeout(5000);
        assert_eq!(builder.config.host, "1.2.3.4");
        assert_eq!(builder.config.port, 9999);
        assert_eq!(builder.config.max_connections, 200);
        assert_eq!(builder.config.heartbeat_interval_ms, 1000);
        assert_eq!(builder.heartbeat_timeout_ms, 5000);
    }

    #[test]
    fn test_builder_build_pool_starts_empty() {
        let dispatcher = Dispatcher::builder().build();
        let stats = dispatcher.pool_stats();
        assert_eq!(stats.total, 0);
    }

    #[test]
    fn test_dispatcher_builder_shortcut() {
        let builder = Dispatcher::builder();
        assert_eq!(builder.config.port, 9876);
    }

    #[tokio::test]
    async fn test_dispatch_result_await_result_receives_value() {
        let (tx, rx) = oneshot::channel();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        let task_result = sample_result(result.task_id);
        tx.send(task_result.clone()).unwrap();
        let received = result.await_result().await.unwrap();
        assert_eq!(received.task_id, task_result.task_id);
        assert_eq!(received.status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_dispatch_result_worker_disconnected() {
        let (tx, rx) = oneshot::channel::<TaskResult>();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        drop(tx);
        let err = result.await_result().await.unwrap_err();
        assert!(matches!(err, DispatchError::WorkerDisconnected));
    }

    #[tokio::test]
    async fn test_dispatch_result_timeout() {
        let (_tx, rx) = oneshot::channel::<TaskResult>();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        let err = result
            .await_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, DispatchError::Timeout));
    }

    #[tokio::test]
    async fn test_dispatch_result_await_with_timeout_receives_value() {
        let (tx, rx) = oneshot::channel();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        tx.send(sample_result(result.task_id)).unwrap();
        let expected_id = result.task_id;
        let received = result
            .await_with_timeout(Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(received.task_id, expected_id);
        assert_eq!(received.status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_dispatch_result_await_with_timeout_worker_disconnected() {
        let (tx, rx) = oneshot::channel::<TaskResult>();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        drop(tx);
        let err = result
            .await_with_timeout(Duration::from_secs(1))
            .await
            .unwrap_err();
        assert!(matches!(err, DispatchError::WorkerDisconnected));
    }

    #[test]
    fn test_dispatch_result_debug_format() {
        let (_tx, rx) = oneshot::channel::<TaskResult>();
        let id = Uuid::new_v4();
        let result = DispatchResult {
            task_id: id,
            receiver: rx,
        };
        let debug = format!("{:?}", result);
        assert!(debug.contains("DispatchResult"));
        assert!(debug.contains(&id.to_string()));
    }

    #[test]
    fn test_dispatch_error_display_no_worker() {
        let err = DispatchError::NoWorkerAvailable {
            task_type: "scan".into(),
        };
        assert_eq!(err.to_string(), "No worker available for task type: scan");
    }

    #[test]
    fn test_dispatch_error_display_worker_disconnected() {
        let err = DispatchError::WorkerDisconnected;
        assert_eq!(
            err.to_string(),
            "Worker disconnected before returning result"
        );
    }

    #[test]
    fn test_dispatch_error_display_timeout() {
        let err = DispatchError::Timeout;
        assert_eq!(err.to_string(), "Task timed out");
    }

    #[test]
    fn test_dispatch_error_display_transport_not_started() {
        let err = DispatchError::TransportNotStarted;
        assert!(err.to_string().contains("Transport not started"));
    }

    #[test]
    fn test_dispatch_error_from_transport_error() {
        let transport_err = TransportError::Closed;
        let dispatch_err: DispatchError = transport_err.into();
        assert!(matches!(
            dispatch_err,
            DispatchError::TransportError(TransportError::Closed)
        ));
    }

    #[test]
    fn test_dispatch_error_from_pool_error() {
        let dispatcher = Dispatcher::builder().build();
        let pool_err = dispatcher.drain_worker("ghost").unwrap_err();
        let dispatch_err: DispatchError = pool_err.into();
        assert!(dispatch_err.to_string().starts_with("Pool error: "));
        assert!(matches!(
            dispatch_err,
            DispatchError::PoolError(PoolError::WorkerNotFound { .. })
        ));
    }

    // =========================================================================
    // Lifecycle tests
    // =========================================================================

    #[tokio::test]
    async fn test_stop_before_start_is_noop() {
        let dispatcher = Dispatcher::builder().build();
        dispatcher.stop().await;
        dispatcher.stop().await;
        assert_eq!(dispatcher.pool_stats().total, 0);
        assert!(dispatcher.workers().is_empty());
    }

    // =========================================================================
    // Pool management builder tests
    // =========================================================================

    #[test]
    fn test_builder_max_pool_size() {
        let builder = DispatcherBuilder::new().max_pool_size(10);
        assert_eq!(builder.max_pool_size, Some(10));
    }

    #[test]
    fn test_builder_min_pool_size() {
        let builder = DispatcherBuilder::new().min_pool_size(2);
        assert_eq!(builder.min_pool_size, Some(2));
    }

    #[test]
    fn test_builder_on_pool_below_min() {
        let builder = DispatcherBuilder::new().on_pool_below_min(|_| {});
        assert!(builder.on_pool_below_min.is_some());
    }

    #[test]
    fn test_builder_pool_limits_chaining() {
        let builder = DispatcherBuilder::new()
            .max_pool_size(50)
            .min_pool_size(5)
            .on_pool_below_min(|_| {});
        assert_eq!(builder.max_pool_size, Some(50));
        assert_eq!(builder.min_pool_size, Some(5));
        assert!(builder.on_pool_below_min.is_some());
    }

    #[test]
    fn test_dispatcher_workers_empty() {
        let dispatcher = Dispatcher::builder().build();
        assert!(dispatcher.workers().is_empty());
    }

    #[test]
    fn test_dispatcher_drain_worker_not_found() {
        let dispatcher = Dispatcher::builder().build();
        let err = dispatcher.drain_worker("ghost").unwrap_err();
        assert!(matches!(err, PoolError::WorkerNotFound { .. }));
    }

    #[test]
    fn test_dispatcher_remove_worker_not_found() {
        let dispatcher = Dispatcher::builder().build();
        let err = dispatcher.remove_worker("ghost").unwrap_err();
        assert!(matches!(err, PoolError::WorkerNotFound { .. }));
    }

    #[test]
    fn test_builder_debug_format() {
        let builder = DispatcherBuilder::new()
            .max_pool_size(10)
            .min_pool_size(2)
            .on_pool_below_min(|_| {});
        let debug = format!("{:?}", builder);
        assert!(debug.contains("DispatcherBuilder"));
        assert!(debug.contains("max_pool_size"));
        assert!(debug.contains("min_pool_size"));
    }
}