Skip to main content

scatter_proxy/
task.rs

1use std::collections::HashSet;
2use std::collections::VecDeque;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
7use bytes::Bytes;
8use http::{HeaderMap, StatusCode};
9use tokio::sync::Mutex as AsyncMutex;
10use tokio::sync::{oneshot, Notify, Semaphore};
11
12use crate::error::ScatterProxyError;
13
/// Response from a successful proxied request.
///
/// The same type also carries the synthetic `502 Bad Gateway` response that
/// [`TaskHandle::with_timeout`] produces when a task's result channel closes
/// without a response having been sent.
#[derive(Debug)]
pub struct ScatterResponse {
    /// HTTP status code of the response.
    pub status: StatusCode,
    /// Response headers.
    pub headers: HeaderMap,
    /// Response body, held in memory as `Bytes`.
    pub body: Bytes,
}
21
22/// Handle returned to the caller when a task is submitted.
23///
24/// Implements `Future<Output = ScatterResponse>` — awaiting it blocks until
25/// the scheduler delivers a successful response.  It will **never** resolve
26/// to an error; the scheduler retries internally forever.
27///
28/// To add a caller-side deadline use [`TaskHandle::with_timeout`].
29#[derive(Debug)]
30pub struct TaskHandle {
31    rx: AsyncMutex<oneshot::Receiver<ScatterResponse>>,
32}
33
34impl TaskHandle {
35    /// Await the result with a caller-side timeout.
36    ///
37    /// Returns `Ok(Some(response))` if the task completes in time, or
38    /// `Ok(None)` if the deadline elapses and the caller should keep waiting on
39    /// the same handle later.
40    pub async fn with_timeout(
41        &self,
42        duration: Duration,
43    ) -> Result<Option<ScatterResponse>, ScatterProxyError> {
44        let mut rx = self.rx.lock().await;
45        match tokio::time::timeout(duration, &mut *rx).await {
46            Ok(Ok(resp)) => Ok(Some(resp)),
47            Ok(Err(_)) => Ok(Some(ScatterResponse {
48                status: StatusCode::BAD_GATEWAY,
49                headers: HeaderMap::new(),
50                body: Bytes::from_static(b"scatter-proxy: internal error - task channel closed"),
51            })),
52            Err(_) => Ok(None),
53        }
54    }
55}
56
/// Internal task entry stored in the pool queue.
///
/// Created by `TaskPool::enqueue` and then moved between the ready queue,
/// the delayed (backoff) heap, and whoever takes it via `TaskPool::pick_next`
/// (the scheduler, per the surrounding comments).
#[derive(Debug)]
pub(crate) struct TaskEntry {
    /// Pool-unique identifier, assigned from an incrementing atomic counter
    /// starting at 1.
    // Only the tests in this file read `id`; presumably it is kept for
    // debugging/logging. TODO confirm before removing the allow.
    #[allow(dead_code)]
    pub id: u64,
    /// The outgoing HTTP request to proxy.
    pub request: reqwest::Request,
    /// Host taken from the request URL (`"unknown"` when the URL has none).
    /// `TaskPool::pick_next` compares it against the set of hosts to skip.
    pub host: String,
    /// How many scheduling rounds this task has been through (for logging).
    /// Starts at 0 when the task is submitted.
    pub attempts: usize,
    /// Sender half — the scheduler sends the final `ScatterResponse` here.
    /// Wrapped in `Option` so it can be `take()`n when the response is sent.
    pub result_tx: Option<oneshot::Sender<ScatterResponse>>,
    /// Description of the last failure (for debug logging); empty on submit.
    pub last_error: String,
}
71
72/// Thread-safe, bounded task pool with async back-pressure.
73///
74/// *   [`submit`](TaskPool::submit) — blocks until capacity is available.
75/// *   [`try_submit`](TaskPool::try_submit) — returns immediately with `Err(PoolFull)` when full.
76/// *   [`submit_timeout`](TaskPool::submit_timeout) — blocks up to a deadline.
77#[derive(Debug)]
78struct DelayedTask {
79    ready_at: Instant,
80    entry: TaskEntry,
81}
82
83impl PartialEq for DelayedTask {
84    fn eq(&self, other: &Self) -> bool {
85        self.ready_at.eq(&other.ready_at)
86    }
87}
88
89impl Eq for DelayedTask {}
90
91impl PartialOrd for DelayedTask {
92    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
93        Some(self.cmp(other))
94    }
95}
96
97impl Ord for DelayedTask {
98    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
99        self.ready_at.cmp(&other.ready_at)
100    }
101}
102
103pub struct TaskPool {
104    queue: Mutex<VecDeque<TaskEntry>>,
105    delayed: Mutex<std::collections::BinaryHeap<std::cmp::Reverse<DelayedTask>>>,
106    capacity: usize,
107    /// Semaphore that tracks how many *task slots* are free.
108    /// A permit is acquired on submit and released when the task permanently
109    /// leaves the system (success delivered to the caller).
110    capacity_sem: Semaphore,
111    next_id: AtomicU64,
112    /// Wakes the scheduler when new work is enqueued.
113    notify: Notify,
114    completed: AtomicU64,
115    failed: AtomicU64,
116    requeued: AtomicU64,
117    zero_available: AtomicU64,
118    skipped_no_permit: AtomicU64,
119    skipped_rate_limit: AtomicU64,
120    skipped_cooldown: AtomicU64,
121    dispatches: AtomicU64,
122}
123
124impl TaskPool {
125    /// Create a new task pool with the given maximum capacity.
126    pub fn new(capacity: usize) -> Self {
127        Self {
128            queue: Mutex::new(VecDeque::new()),
129            delayed: Mutex::new(std::collections::BinaryHeap::new()),
130            capacity,
131            capacity_sem: Semaphore::new(capacity),
132            next_id: AtomicU64::new(1),
133            notify: Notify::new(),
134            completed: AtomicU64::new(0),
135            failed: AtomicU64::new(0),
136            requeued: AtomicU64::new(0),
137            zero_available: AtomicU64::new(0),
138            skipped_no_permit: AtomicU64::new(0),
139            skipped_rate_limit: AtomicU64::new(0),
140            skipped_cooldown: AtomicU64::new(0),
141            dispatches: AtomicU64::new(0),
142        }
143    }
144
145    // ── submit variants ─────────────────────────────────────────────────
146
147    /// Submit a request, **blocking** until the pool has capacity.
148    ///
149    /// Returns a [`TaskHandle`] whose `.await` blocks until a proxied response
150    /// is obtained.
151    pub async fn submit(&self, request: reqwest::Request) -> TaskHandle {
152        // Wait for a free slot.
153        let permit = self
154            .capacity_sem
155            .acquire()
156            .await
157            .expect("capacity semaphore closed");
158        permit.forget(); // we manually add_permits in mark_completed
159
160        self.enqueue(request)
161    }
162
163    /// Non-blocking submit.  Returns `Err(PoolFull)` when the pool is at capacity.
164    pub fn try_submit(&self, request: reqwest::Request) -> Result<TaskHandle, ScatterProxyError> {
165        let permit = self
166            .capacity_sem
167            .try_acquire()
168            .map_err(|_| ScatterProxyError::PoolFull {
169                capacity: self.capacity,
170            })?;
171        permit.forget();
172        Ok(self.enqueue(request))
173    }
174
175    /// Submit with a caller-side timeout on the *submission* itself.
176    ///
177    /// Blocks up to `timeout` waiting for pool capacity.  Returns
178    /// `Err(Timeout)` if the deadline elapses before a slot opens.
179    pub async fn submit_timeout(
180        &self,
181        request: reqwest::Request,
182        timeout: Duration,
183    ) -> Result<TaskHandle, ScatterProxyError> {
184        match tokio::time::timeout(timeout, self.submit(request)).await {
185            Ok(handle) => Ok(handle),
186            Err(_) => Err(ScatterProxyError::Timeout { elapsed: timeout }),
187        }
188    }
189
190    // ── batch variants ──────────────────────────────────────────────────
191
192    /// Submit a batch of requests, **blocking** until pool capacity is available
193    /// for each one (sequentially).
194    pub async fn submit_batch(&self, requests: Vec<reqwest::Request>) -> Vec<TaskHandle> {
195        let mut handles = Vec::with_capacity(requests.len());
196        for req in requests {
197            handles.push(self.submit(req).await);
198        }
199        handles
200    }
201
202    /// Non-blocking atomic batch submit.  If there is not enough room for the
203    /// **entire** batch, no tasks are added and `Err(PoolFull)` is returned.
204    pub fn try_submit_batch(
205        &self,
206        requests: Vec<reqwest::Request>,
207    ) -> Result<Vec<TaskHandle>, ScatterProxyError> {
208        let count = requests.len();
209        if count == 0 {
210            return Ok(Vec::new());
211        }
212
213        // Try to acquire `count` permits at once.
214        let permit = self
215            .capacity_sem
216            .try_acquire_many(count as u32)
217            .map_err(|_| ScatterProxyError::PoolFull {
218                capacity: self.capacity,
219            })?;
220        permit.forget();
221
222        let mut handles = Vec::with_capacity(count);
223        for req in requests {
224            handles.push(self.enqueue(req));
225        }
226        Ok(handles)
227    }
228
229    // ── internal ────────────────────────────────────────────────────────
230
231    /// Construct a `TaskEntry`, push it onto the queue, wake the scheduler,
232    /// and return the handle.  Caller **must** have already acquired a
233    /// capacity permit.
234    fn enqueue(&self, request: reqwest::Request) -> TaskHandle {
235        let host = request.url().host_str().unwrap_or("unknown").to_string();
236        let (tx, rx) = oneshot::channel();
237        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
238
239        let entry = TaskEntry {
240            id,
241            request,
242            host,
243            attempts: 0,
244            result_tx: Some(tx),
245            last_error: String::new(),
246        };
247
248        {
249            self.promote_ready_delayed();
250            let mut queue = self.queue.lock().unwrap();
251            queue.push_back(entry);
252        }
253
254        self.notify.notify_one();
255        TaskHandle {
256            rx: AsyncMutex::new(rx),
257        }
258    }
259
260    // ── scheduler helpers ───────────────────────────────────────────────
261
262    /// Pick the next eligible task from the front of the queue.
263    ///
264    /// Tasks whose host appears in `skip_hosts` (e.g. circuit-broken hosts) are
265    /// left in the queue.  Returns `None` when no eligible task is found.
266    pub(crate) fn promote_ready_delayed(&self) -> usize {
267        let now = Instant::now();
268        let mut delayed = self.delayed.lock().unwrap();
269        if delayed.is_empty() {
270            return 0;
271        }
272        let mut ready = Vec::new();
273        while let Some(std::cmp::Reverse(item)) = delayed.peek() {
274            if item.ready_at <= now {
275                let std::cmp::Reverse(item) = delayed.pop().expect("heap peeked item must pop");
276                ready.push(item.entry);
277            } else {
278                break;
279            }
280        }
281        drop(delayed);
282        if ready.is_empty() {
283            return 0;
284        }
285        let count = ready.len();
286        let mut queue = self.queue.lock().unwrap();
287        for entry in ready {
288            queue.push_back(entry);
289        }
290        count
291    }
292
293    pub(crate) fn next_delayed_ready_in(&self) -> Option<Duration> {
294        let delayed = self.delayed.lock().unwrap();
295        let now = Instant::now();
296        delayed
297            .peek()
298            .map(|d| d.0.ready_at.saturating_duration_since(now))
299    }
300
301    pub(crate) fn pick_next(&self, skip_hosts: &HashSet<String>) -> Option<TaskEntry> {
302        let mut queue = self.queue.lock().unwrap();
303        if skip_hosts.is_empty() {
304            return queue.pop_front();
305        }
306
307        let len = queue.len();
308        for _ in 0..len {
309            let entry = queue.pop_front()?;
310            if !skip_hosts.contains(&entry.host) {
311                return Some(entry);
312            }
313            queue.push_back(entry);
314        }
315
316        None
317    }
318
319    /// Push a task back to the tail of the queue for retry.
320    pub(crate) fn push_back(&self, entry: TaskEntry) {
321        self.requeued.fetch_add(1, Ordering::Relaxed);
322        {
323            let mut queue = self.queue.lock().unwrap();
324            queue.push_back(entry);
325        }
326        self.notify.notify_one();
327    }
328
329    pub(crate) fn push_delayed(&self, entry: TaskEntry, delay: Duration) {
330        self.requeued.fetch_add(1, Ordering::Relaxed);
331        {
332            let mut delayed = self.delayed.lock().unwrap();
333            delayed.push(std::cmp::Reverse(DelayedTask {
334                ready_at: Instant::now() + delay,
335                entry,
336            }));
337        }
338        self.notify.notify_one();
339    }
340
341    /// Number of tasks currently waiting in the queue.
342    pub fn pending_count(&self) -> usize {
343        let queue = self.queue.lock().unwrap();
344        queue.len()
345    }
346
347    pub fn delayed_count(&self) -> usize {
348        let delayed = self.delayed.lock().unwrap();
349        delayed.len()
350    }
351
352    /// Total number of tasks that completed successfully.
353    pub fn completed_count(&self) -> u64 {
354        self.completed.load(Ordering::Relaxed)
355    }
356
357    /// Increment the completed counter **and** release one capacity-semaphore
358    /// permit so that a blocked [`submit`](TaskPool::submit) can proceed.
359    pub(crate) fn mark_completed(&self) {
360        self.completed.fetch_add(1, Ordering::Relaxed);
361        self.capacity_sem.add_permits(1);
362    }
363
364    /// Increment the failed counter and release one capacity-semaphore permit.
365    /// Used for unrecoverable tasks (e.g. non-cloneable body).
366    pub(crate) fn mark_failed(&self) {
367        self.failed.fetch_add(1, Ordering::Relaxed);
368        self.capacity_sem.add_permits(1);
369    }
370
371    /// Total number of tasks that failed unrecoverably.
372    pub fn failed_count(&self) -> u64 {
373        self.failed.load(Ordering::Relaxed)
374    }
375    pub fn requeued_count(&self) -> u64 {
376        self.requeued.load(Ordering::Relaxed)
377    }
378
379    pub(crate) fn mark_zero_available(&self) {
380        self.zero_available.fetch_add(1, Ordering::Relaxed);
381    }
382
383    pub fn zero_available_count(&self) -> u64 {
384        self.zero_available.load(Ordering::Relaxed)
385    }
386
387    pub(crate) fn mark_skipped_no_permit(&self) {
388        self.skipped_no_permit.fetch_add(1, Ordering::Relaxed);
389    }
390
391    pub fn skipped_no_permit_count(&self) -> u64 {
392        self.skipped_no_permit.load(Ordering::Relaxed)
393    }
394
395    pub(crate) fn mark_skipped_rate_limit(&self) {
396        self.skipped_rate_limit.fetch_add(1, Ordering::Relaxed);
397    }
398
399    pub fn skipped_rate_limit_count(&self) -> u64 {
400        self.skipped_rate_limit.load(Ordering::Relaxed)
401    }
402
403    pub(crate) fn mark_skipped_cooldown(&self) {
404        self.skipped_cooldown.fetch_add(1, Ordering::Relaxed);
405    }
406
407    pub fn skipped_cooldown_count(&self) -> u64 {
408        self.skipped_cooldown.load(Ordering::Relaxed)
409    }
410
411    pub(crate) fn mark_dispatch(&self) {
412        self.dispatches.fetch_add(1, Ordering::Relaxed);
413    }
414
415    pub fn dispatch_count(&self) -> u64 {
416        self.dispatches.load(Ordering::Relaxed)
417    }
418
419    /// Wait until a task becomes available (a task is submitted or pushed back).
420    #[allow(dead_code)]
421    pub(crate) async fn notified(&self) {
422        self.notify.notified().await;
423    }
424}
425
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a trivial GET request for testing.
    fn test_request() -> reqwest::Request {
        reqwest::Client::new()
            .get("http://example.com/test")
            .build()
            .unwrap()
    }

    // ── pool basics ─────────────────────────────────────────────────────

    #[test]
    fn new_pool_has_zero_pending() {
        let pool = TaskPool::new(10);
        assert_eq!(pool.pending_count(), 0);
        assert_eq!(pool.delayed_count(), 0);
        assert_eq!(pool.completed_count(), 0);
    }

    // ── try_submit ──────────────────────────────────────────────────────

    #[test]
    fn try_submit_increments_pending_count() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn try_submit_returns_pool_full_when_at_capacity() {
        let pool = TaskPool::new(2);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let result = pool.try_submit(test_request());
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::PoolFull { capacity } => assert_eq!(capacity, 2),
            other => panic!("expected PoolFull, got {other:?}"),
        }
    }

    #[test]
    fn try_submit_assigns_incrementing_ids() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t2.id > t1.id);
    }

    #[test]
    fn try_submit_extracts_host_from_url() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.host, "example.com");
    }

    // ── try_submit_batch ────────────────────────────────────────────────

    #[test]
    fn try_submit_batch_adds_all_tasks() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request(), test_request()];
        let handles = pool.try_submit_batch(reqs).unwrap();
        assert_eq!(handles.len(), 3);
        assert_eq!(pool.pending_count(), 3);
    }

    #[test]
    fn try_submit_batch_atomic_rejection_when_pool_full() {
        let pool = TaskPool::new(2);
        let reqs = vec![test_request(), test_request(), test_request()];
        let result = pool.try_submit_batch(reqs);
        assert!(result.is_err());
        assert_eq!(pool.pending_count(), 0);
    }

    #[test]
    fn try_submit_batch_empty_vec_is_ok() {
        let pool = TaskPool::new(10);
        let handles = pool.try_submit_batch(vec![]).unwrap();
        assert!(handles.is_empty());
    }

    // ── async submit ────────────────────────────────────────────────────

    #[tokio::test]
    async fn submit_blocks_then_proceeds_after_mark_completed() {
        let pool = std::sync::Arc::new(TaskPool::new(1));
        // Fill the pool.
        let _h1 = pool.try_submit(test_request()).unwrap();

        let pool2 = pool.clone();
        let join = tokio::spawn(async move {
            // This should block because the pool is full.
            let _handle = pool2.submit(test_request()).await;
        });

        // Give the spawned task a moment to park on the semaphore.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(pool.pending_count(), 1); // only the first task

        // Free a slot.
        {
            let skip = HashSet::new();
            let _task = pool.pick_next(&skip).unwrap();
            pool.mark_completed();
        }

        // Now the spawned submit should unblock.
        join.await.unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    #[tokio::test]
    async fn submit_timeout_returns_err_on_expiry() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();

        let result = pool
            .submit_timeout(test_request(), Duration::from_millis(50))
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::Timeout { elapsed } => {
                assert_eq!(elapsed, Duration::from_millis(50));
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn submit_batch_processes_all() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request()];
        let handles = pool.submit_batch(reqs).await;
        assert_eq!(handles.len(), 2);
        assert_eq!(pool.pending_count(), 2);
    }

    // ── pick_next ───────────────────────────────────────────────────────

    #[test]
    fn pick_next_returns_fifo_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t1.id < t2.id);
    }

    #[test]
    fn pick_next_skips_circuit_broken_hosts() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap(); // example.com

        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn pick_next_returns_none_when_all_hosts_skipped() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn pick_next_returns_none_when_empty() {
        let pool = TaskPool::new(10);
        let skip = HashSet::new();
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn pick_next_selects_first_non_skipped_preserves_order() {
        let pool = TaskPool::new(10);
        // Task 1: example.com
        let _h1 = pool.try_submit(test_request()).unwrap();
        // Task 2: other.com
        let req2 = reqwest::Client::new()
            .get("http://other.com/path")
            .build()
            .unwrap();
        let _h2 = pool.try_submit(req2).unwrap();
        // Task 3: example.com
        let _h3 = pool.try_submit(test_request()).unwrap();

        let mut skip = HashSet::new();
        skip.insert("example.com".into());

        let picked = pool.pick_next(&skip).unwrap();
        assert_eq!(picked.host, "other.com");
        assert_eq!(pool.pending_count(), 2);
    }

    // ── push_back ───────────────────────────────────────────────────────

    #[test]
    fn push_back_requeues_to_tail() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let id1 = t1.id;
        pool.push_back(t1);

        // t1 should now be after t2.
        let t2 = pool.pick_next(&skip).unwrap();
        let re_t1 = pool.pick_next(&skip).unwrap();
        assert!(t2.id > id1);
        assert_eq!(re_t1.id, id1);
    }

    // ── push_delayed ────────────────────────────────────────────────────

    #[test]
    fn delayed_task_promotes_when_ready() {
        let pool = TaskPool::new(10);
        let _ = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        pool.push_delayed(task, Duration::from_millis(10));
        assert_eq!(pool.delayed_count(), 1);
        std::thread::sleep(Duration::from_millis(20));
        let promoted = pool.promote_ready_delayed();
        assert_eq!(promoted, 1);
        assert_eq!(pool.delayed_count(), 0);
        assert_eq!(pool.pending_count(), 1);
    }

    // ── mark_completed ──────────────────────────────────────────────────

    #[test]
    fn mark_completed_increments_counter() {
        let pool = TaskPool::new(10);
        // Pair the completion with a real submission so the capacity
        // semaphore is not inflated beyond its configured size.
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let _task = pool.pick_next(&skip).unwrap();
        pool.mark_completed();
        assert_eq!(pool.completed_count(), 1);
    }

    // ── TaskHandle ──────────────────────────────────────────────────────

    #[tokio::test]
    async fn task_handle_receives_success() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ScatterResponse {
                status: StatusCode::OK,
                headers: HeaderMap::new(),
                body: Bytes::from_static(b"hello"),
            });
        }

        let resp = handle
            .with_timeout(Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body.as_ref(), b"hello");
    }

    #[tokio::test]
    async fn task_handle_returns_502_when_sender_dropped() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        // Pick and drop the task without sending a response.
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        drop(task);

        let resp = handle
            .with_timeout(Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resp.status, StatusCode::BAD_GATEWAY);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_ok() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ScatterResponse {
                status: StatusCode::OK,
                headers: HeaderMap::new(),
                body: Bytes::from_static(b"ok"),
            });
        }

        let resp = handle
            .with_timeout(Duration::from_secs(5))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resp.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_expires() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let result = handle
            .with_timeout(Duration::from_millis(50))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    // ── notified ────────────────────────────────────────────────────────

    #[tokio::test]
    async fn notified_wakes_on_try_submit() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let pool2 = pool.clone();

        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        tokio::time::sleep(Duration::from_millis(20)).await;
        let _h = pool.try_submit(test_request()).unwrap();

        assert!(waiter.await.unwrap());
    }

    #[tokio::test]
    async fn notified_wakes_on_push_back() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let _h = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();

        let pool2 = pool.clone();
        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        tokio::time::sleep(Duration::from_millis(20)).await;
        pool.push_back(task);

        assert!(waiter.await.unwrap());
    }

    // ── edge cases ──────────────────────────────────────────────────────

    #[test]
    fn pool_with_zero_capacity_rejects_everything() {
        let pool = TaskPool::new(0);
        let result = pool.try_submit(test_request());
        assert!(result.is_err());
    }

    #[test]
    fn pool_allows_try_submit_after_mark_completed_frees_space() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        // Pool is full.
        assert!(pool.try_submit(test_request()).is_err());

        // Simulate task completion.
        let skip = HashSet::new();
        let _task = pool.pick_next(&skip).unwrap();
        pool.mark_completed();

        // Now there's room again.
        let _h2 = pool.try_submit(test_request()).unwrap();
    }

    #[test]
    fn task_entry_has_correct_defaults_on_try_submit() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.attempts, 0);
        assert!(task.last_error.is_empty());
        assert!(task.result_tx.is_some());
    }

    #[test]
    fn scatter_response_debug() {
        let resp = ScatterResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: Bytes::from_static(b"test"),
        };
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("200"));
    }
}