Skip to main content

scatter_proxy/
task.rs

1use std::collections::HashSet;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Mutex;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use bytes::Bytes;
11use http::{HeaderMap, StatusCode};
12use tokio::sync::{oneshot, Notify, Semaphore};
13
14use crate::error::ScatterProxyError;
15
/// Response from a successful proxied request.
///
/// This is the value a [`TaskHandle`] resolves to.  The body is held fully in
/// memory as [`Bytes`].  If a task's result channel is closed without a
/// response being sent, the handle resolves to a synthetic
/// `502 Bad Gateway` value of this type instead.
#[derive(Debug)]
pub struct ScatterResponse {
    /// HTTP status code of the response.
    pub status: StatusCode,
    /// Response headers (empty for the synthetic channel-closed response).
    pub headers: HeaderMap,
    /// Complete response body.
    pub body: Bytes,
}
23
/// Handle returned to the caller when a task is submitted.
///
/// Implements `Future<Output = ScatterResponse>`.  Awaiting it suspends the
/// calling task (it does not block a thread) until the scheduler delivers a
/// response.  It will **never** resolve to an error; the scheduler retries
/// internally forever.  The one exception to "real response" is when the
/// sending side is dropped without sending.  The handle then resolves to a
/// synthetic `502 Bad Gateway` response rather than panicking.
///
/// To add a caller-side deadline use [`TaskHandle::with_timeout`].
#[derive(Debug)]
pub struct TaskHandle {
    // Receiving half of the per-task oneshot channel; the sending half is
    // stored in the task's `TaskEntry::result_tx`.
    rx: oneshot::Receiver<ScatterResponse>,
}
35
36impl TaskHandle {
37    /// Await the result with a caller-side timeout.
38    ///
39    /// Returns `Ok(response)` if the task completes in time, or
40    /// `Err(ScatterProxyError::Timeout)` if the deadline elapses.
41    ///
42    /// ```ignore
43    /// let resp = handle.with_timeout(Duration::from_secs(30)).await?;
44    /// ```
45    pub async fn with_timeout(
46        self,
47        duration: Duration,
48    ) -> Result<ScatterResponse, ScatterProxyError> {
49        match tokio::time::timeout(duration, self).await {
50            Ok(resp) => Ok(resp),
51            Err(_) => Err(ScatterProxyError::Timeout { elapsed: duration }),
52        }
53    }
54}
55
56impl Future for TaskHandle {
57    type Output = ScatterResponse;
58
59    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
60        match Pin::new(&mut self.rx).poll(cx) {
61            Poll::Ready(Ok(resp)) => Poll::Ready(resp),
62            Poll::Ready(Err(_)) => {
63                // Channel closed without sending — should not happen in normal
64                // operation.  Return a synthetic 502 so the caller never panics.
65                Poll::Ready(ScatterResponse {
66                    status: StatusCode::BAD_GATEWAY,
67                    headers: HeaderMap::new(),
68                    body: Bytes::from_static(
69                        b"scatter-proxy: internal error - task channel closed",
70                    ),
71                })
72            }
73            Poll::Pending => Poll::Pending,
74        }
75    }
76}
77
/// Internal task entry stored in the pool queue.
///
/// Created by `TaskPool::enqueue` with `attempts = 0`, an empty `last_error`
/// and `result_tx = Some(..)`.
pub(crate) struct TaskEntry {
    // `id` is not read by non-test code in this file; the allow keeps it
    // available for diagnostics without a dead-code warning.
    #[allow(dead_code)]
    /// Pool-unique identifier, assigned in submission order starting at 1.
    pub id: u64,
    /// The request to be proxied.
    pub request: reqwest::Request,
    /// Host part of the request URL (`"unknown"` when the URL has none).
    /// `TaskPool::pick_next` compares this against its `skip_hosts` set.
    pub host: String,
    /// How many scheduling rounds this task has been through (for logging).
    pub attempts: usize,
    /// Sender half — the scheduler sends the final `ScatterResponse` here.
    /// Wrapped in `Option` so it can be moved out with `take()` when sending.
    pub result_tx: Option<oneshot::Sender<ScatterResponse>>,
    /// Description of the last failure (for debug logging).
    pub last_error: String,
}
91
/// Thread-safe, bounded task pool with async back-pressure.
///
/// *   [`submit`](TaskPool::submit) — waits until capacity is available.
/// *   [`try_submit`](TaskPool::try_submit) — returns immediately with `Err(PoolFull)` when full.
/// *   [`submit_timeout`](TaskPool::submit_timeout) — waits up to a deadline.
pub struct TaskPool {
    /// FIFO of tasks waiting to be scheduled.  Tasks taken by `pick_next`
    /// are removed while in flight and re-added with `push_back` on retry.
    queue: Mutex<VecDeque<TaskEntry>>,
    /// Configured maximum number of live tasks (reported in `PoolFull`).
    capacity: usize,
    /// Semaphore that tracks how many *task slots* are free.
    /// A permit is acquired (and forgotten) on submit.  It is returned only
    /// when the task permanently leaves the system, via `mark_completed`
    /// (success delivered to the caller) or `mark_failed` (unrecoverable).
    capacity_sem: Semaphore,
    /// Source of task ids; starts at 1.
    next_id: AtomicU64,
    /// Wakes the scheduler when new work is enqueued.
    notify: Notify,
    /// Number of tasks released through `mark_completed`.
    completed: AtomicU64,
    /// Number of tasks released through `mark_failed`.
    failed: AtomicU64,
}
110
111impl TaskPool {
112    /// Create a new task pool with the given maximum capacity.
113    pub fn new(capacity: usize) -> Self {
114        Self {
115            queue: Mutex::new(VecDeque::new()),
116            capacity,
117            capacity_sem: Semaphore::new(capacity),
118            next_id: AtomicU64::new(1),
119            notify: Notify::new(),
120            completed: AtomicU64::new(0),
121            failed: AtomicU64::new(0),
122        }
123    }
124
125    // ── submit variants ─────────────────────────────────────────────────
126
127    /// Submit a request, **blocking** until the pool has capacity.
128    ///
129    /// Returns a [`TaskHandle`] whose `.await` blocks until a proxied response
130    /// is obtained.
131    pub async fn submit(&self, request: reqwest::Request) -> TaskHandle {
132        // Wait for a free slot.
133        let permit = self
134            .capacity_sem
135            .acquire()
136            .await
137            .expect("capacity semaphore closed");
138        permit.forget(); // we manually add_permits in mark_completed
139
140        self.enqueue(request)
141    }
142
143    /// Non-blocking submit.  Returns `Err(PoolFull)` when the pool is at capacity.
144    pub fn try_submit(&self, request: reqwest::Request) -> Result<TaskHandle, ScatterProxyError> {
145        let permit = self
146            .capacity_sem
147            .try_acquire()
148            .map_err(|_| ScatterProxyError::PoolFull {
149                capacity: self.capacity,
150            })?;
151        permit.forget();
152        Ok(self.enqueue(request))
153    }
154
155    /// Submit with a caller-side timeout on the *submission* itself.
156    ///
157    /// Blocks up to `timeout` waiting for pool capacity.  Returns
158    /// `Err(Timeout)` if the deadline elapses before a slot opens.
159    pub async fn submit_timeout(
160        &self,
161        request: reqwest::Request,
162        timeout: Duration,
163    ) -> Result<TaskHandle, ScatterProxyError> {
164        match tokio::time::timeout(timeout, self.submit(request)).await {
165            Ok(handle) => Ok(handle),
166            Err(_) => Err(ScatterProxyError::Timeout { elapsed: timeout }),
167        }
168    }
169
170    // ── batch variants ──────────────────────────────────────────────────
171
172    /// Submit a batch of requests, **blocking** until pool capacity is available
173    /// for each one (sequentially).
174    pub async fn submit_batch(&self, requests: Vec<reqwest::Request>) -> Vec<TaskHandle> {
175        let mut handles = Vec::with_capacity(requests.len());
176        for req in requests {
177            handles.push(self.submit(req).await);
178        }
179        handles
180    }
181
182    /// Non-blocking atomic batch submit.  If there is not enough room for the
183    /// **entire** batch, no tasks are added and `Err(PoolFull)` is returned.
184    pub fn try_submit_batch(
185        &self,
186        requests: Vec<reqwest::Request>,
187    ) -> Result<Vec<TaskHandle>, ScatterProxyError> {
188        let count = requests.len();
189        if count == 0 {
190            return Ok(Vec::new());
191        }
192
193        // Try to acquire `count` permits at once.
194        let permit = self
195            .capacity_sem
196            .try_acquire_many(count as u32)
197            .map_err(|_| ScatterProxyError::PoolFull {
198                capacity: self.capacity,
199            })?;
200        permit.forget();
201
202        let mut handles = Vec::with_capacity(count);
203        for req in requests {
204            handles.push(self.enqueue(req));
205        }
206        Ok(handles)
207    }
208
209    // ── internal ────────────────────────────────────────────────────────
210
211    /// Construct a `TaskEntry`, push it onto the queue, wake the scheduler,
212    /// and return the handle.  Caller **must** have already acquired a
213    /// capacity permit.
214    fn enqueue(&self, request: reqwest::Request) -> TaskHandle {
215        let host = request.url().host_str().unwrap_or("unknown").to_string();
216        let (tx, rx) = oneshot::channel();
217        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
218
219        let entry = TaskEntry {
220            id,
221            request,
222            host,
223            attempts: 0,
224            result_tx: Some(tx),
225            last_error: String::new(),
226        };
227
228        {
229            let mut queue = self.queue.lock().unwrap();
230            queue.push_back(entry);
231        }
232
233        self.notify.notify_one();
234        TaskHandle { rx }
235    }
236
237    // ── scheduler helpers ───────────────────────────────────────────────
238
239    /// Pick the next eligible task from the front of the queue.
240    ///
241    /// Tasks whose host appears in `skip_hosts` (e.g. circuit-broken hosts) are
242    /// left in the queue.  Returns `None` when no eligible task is found.
243    pub(crate) fn pick_next(&self, skip_hosts: &HashSet<String>) -> Option<TaskEntry> {
244        let mut queue = self.queue.lock().unwrap();
245        let len = queue.len();
246
247        for i in 0..len {
248            if let Some(entry) = queue.get(i) {
249                if !skip_hosts.contains(&entry.host) {
250                    return queue.remove(i);
251                }
252            }
253        }
254
255        None
256    }
257
258    /// Push a task back to the tail of the queue for retry.
259    pub(crate) fn push_back(&self, entry: TaskEntry) {
260        {
261            let mut queue = self.queue.lock().unwrap();
262            queue.push_back(entry);
263        }
264        self.notify.notify_one();
265    }
266
267    /// Number of tasks currently waiting in the queue.
268    pub fn pending_count(&self) -> usize {
269        let queue = self.queue.lock().unwrap();
270        queue.len()
271    }
272
273    /// Total number of tasks that completed successfully.
274    pub fn completed_count(&self) -> u64 {
275        self.completed.load(Ordering::Relaxed)
276    }
277
278    /// Increment the completed counter **and** release one capacity-semaphore
279    /// permit so that a blocked [`submit`](TaskPool::submit) can proceed.
280    pub(crate) fn mark_completed(&self) {
281        self.completed.fetch_add(1, Ordering::Relaxed);
282        self.capacity_sem.add_permits(1);
283    }
284
285    /// Increment the failed counter and release one capacity-semaphore permit.
286    /// Used for unrecoverable tasks (e.g. non-cloneable body).
287    pub(crate) fn mark_failed(&self) {
288        self.failed.fetch_add(1, Ordering::Relaxed);
289        self.capacity_sem.add_permits(1);
290    }
291
292    /// Total number of tasks that failed unrecoverably.
293    pub fn failed_count(&self) -> u64 {
294        self.failed.load(Ordering::Relaxed)
295    }
296
297    /// Wait until a task becomes available (a task is submitted or pushed back).
298    #[allow(dead_code)]
299    pub(crate) async fn notified(&self) {
300        self.notify.notified().await;
301    }
302}
303
304// ─── Tests ───────────────────────────────────────────────────────────────────
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a trivial GET request for testing (host `example.com`).
    fn test_request() -> reqwest::Request {
        request_for("http://example.com/test")
    }

    /// Build a GET request for an arbitrary URL.
    fn request_for(url: &str) -> reqwest::Request {
        reqwest::Client::new().get(url).build().unwrap()
    }

    // ── pool basics ─────────────────────────────────────────────────────

    #[test]
    fn new_pool_has_zero_pending() {
        let pool = TaskPool::new(10);
        assert_eq!(pool.pending_count(), 0);
        assert_eq!(pool.completed_count(), 0);
        assert_eq!(pool.failed_count(), 0);
    }

    // ── try_submit ──────────────────────────────────────────────────────

    #[test]
    fn try_submit_increments_pending_count() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn try_submit_returns_pool_full_when_at_capacity() {
        let pool = TaskPool::new(2);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let result = pool.try_submit(test_request());
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::PoolFull { capacity } => assert_eq!(capacity, 2),
            other => panic!("expected PoolFull, got {other:?}"),
        }
    }

    #[test]
    fn try_submit_assigns_incrementing_ids() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t2.id > t1.id);
    }

    #[test]
    fn try_submit_extracts_host_from_url() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.host, "example.com");
    }

    // ── try_submit_batch ────────────────────────────────────────────────

    #[test]
    fn try_submit_batch_adds_all_tasks() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request(), test_request()];
        let handles = pool.try_submit_batch(reqs).unwrap();
        assert_eq!(handles.len(), 3);
        assert_eq!(pool.pending_count(), 3);
    }

    #[test]
    fn try_submit_batch_atomic_rejection_when_pool_full() {
        let pool = TaskPool::new(2);
        let reqs = vec![test_request(), test_request(), test_request()];
        let result = pool.try_submit_batch(reqs);
        assert!(result.is_err());
        assert_eq!(pool.pending_count(), 0);

        // The rejected batch must not have consumed any capacity.
        let handles = pool
            .try_submit_batch(vec![test_request(), test_request()])
            .unwrap();
        assert_eq!(handles.len(), 2);
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn try_submit_batch_empty_vec_is_ok() {
        let pool = TaskPool::new(10);
        let handles = pool.try_submit_batch(vec![]).unwrap();
        assert!(handles.is_empty());
    }

    // ── async submit ────────────────────────────────────────────────────

    #[tokio::test]
    async fn submit_blocks_then_proceeds_after_mark_completed() {
        let pool = std::sync::Arc::new(TaskPool::new(1));
        // Fill the pool.
        let _h1 = pool.try_submit(test_request()).unwrap();

        let pool2 = pool.clone();
        let join = tokio::spawn(async move {
            // This should wait because the pool is full.
            let _handle = pool2.submit(test_request()).await;
        });

        // Give the spawned task a moment to park on the semaphore.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(pool.pending_count(), 1); // only the first task

        // Free a slot.
        {
            let skip = HashSet::new();
            let _task = pool.pick_next(&skip).unwrap();
            pool.mark_completed();
        }

        // Now the spawned submit should unblock.
        join.await.unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    #[tokio::test]
    async fn submit_timeout_returns_err_on_expiry() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();

        let result = pool
            .submit_timeout(test_request(), Duration::from_millis(50))
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::Timeout { elapsed } => {
                assert_eq!(elapsed, Duration::from_millis(50));
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
        // The timed-out submission must not have enqueued anything.
        assert_eq!(pool.pending_count(), 1);
    }

    #[tokio::test]
    async fn submit_batch_processes_all() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request()];
        let handles = pool.submit_batch(reqs).await;
        assert_eq!(handles.len(), 2);
        assert_eq!(pool.pending_count(), 2);
    }

    // ── pick_next ───────────────────────────────────────────────────────

    #[test]
    fn pick_next_returns_fifo_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t1.id < t2.id);
    }

    #[test]
    fn pick_next_skips_circuit_broken_hosts() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap(); // example.com

        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn pick_next_returns_none_when_all_hosts_skipped() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn pick_next_returns_none_when_empty() {
        let pool = TaskPool::new(10);
        let skip = HashSet::new();
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn pick_next_selects_first_non_skipped_preserves_order() {
        let pool = TaskPool::new(10);
        // Task 1: example.com
        let _h1 = pool.try_submit(test_request()).unwrap();
        // Task 2: other.com
        let _h2 = pool.try_submit(request_for("http://other.com/path")).unwrap();
        // Task 3: example.com
        let _h3 = pool.try_submit(test_request()).unwrap();

        let mut skip = HashSet::new();
        skip.insert("example.com".into());

        let picked = pool.pick_next(&skip).unwrap();
        assert_eq!(picked.host, "other.com");
        assert_eq!(pool.pending_count(), 2);

        // The skipped tasks keep their relative order.
        let empty = HashSet::new();
        let first = pool.pick_next(&empty).unwrap();
        let second = pool.pick_next(&empty).unwrap();
        assert!(first.id < second.id);
    }

    // ── push_back ───────────────────────────────────────────────────────

    #[test]
    fn push_back_requeues_to_tail() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let id1 = t1.id;
        pool.push_back(t1);

        // t1 should now be after t2.
        let t2 = pool.pick_next(&skip).unwrap();
        let re_t1 = pool.pick_next(&skip).unwrap();
        assert_ne!(t2.id, id1, "the untouched task should be picked first");
        assert_eq!(re_t1.id, id1, "the re-queued task should come last");
        assert!(pool.pick_next(&skip).is_none());
    }

    // ── mark_completed / mark_failed ────────────────────────────────────

    #[test]
    fn mark_completed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_completed();
        assert_eq!(pool.completed_count(), 1);
    }

    #[test]
    fn mark_failed_increments_counter_and_frees_slot() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        assert!(pool.try_submit(test_request()).is_err());

        let skip = HashSet::new();
        let _task = pool.pick_next(&skip).unwrap();
        pool.mark_failed();

        assert_eq!(pool.failed_count(), 1);
        assert_eq!(pool.completed_count(), 0);
        let _h2 = pool.try_submit(test_request()).unwrap();
    }

    // ── TaskHandle ──────────────────────────────────────────────────────

    #[tokio::test]
    async fn task_handle_receives_success() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ScatterResponse {
                status: StatusCode::OK,
                headers: HeaderMap::new(),
                body: Bytes::from_static(b"hello"),
            });
        }

        let resp = handle.await;
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body.as_ref(), b"hello");
    }

    #[tokio::test]
    async fn task_handle_returns_502_when_sender_dropped() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        // Pick and drop the task without sending a response.
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        drop(task);

        let resp = handle.await;
        assert_eq!(resp.status, StatusCode::BAD_GATEWAY);
        assert!(resp.headers.is_empty());
    }

    #[tokio::test]
    async fn task_handle_with_timeout_ok() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ScatterResponse {
                status: StatusCode::OK,
                headers: HeaderMap::new(),
                body: Bytes::from_static(b"ok"),
            });
        }

        let resp = handle.with_timeout(Duration::from_secs(5)).await.unwrap();
        assert_eq!(resp.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_expires() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let result = handle.with_timeout(Duration::from_millis(50)).await;
        match result {
            Err(ScatterProxyError::Timeout { elapsed }) => {
                assert_eq!(elapsed, Duration::from_millis(50));
            }
            Err(other) => panic!("expected Timeout, got {other:?}"),
            Ok(resp) => panic!("expected Timeout, got response {resp:?}"),
        }
    }

    // ── notified ────────────────────────────────────────────────────────

    #[tokio::test]
    async fn notified_wakes_on_try_submit() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let pool2 = pool.clone();

        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        tokio::time::sleep(Duration::from_millis(20)).await;
        let _h = pool.try_submit(test_request()).unwrap();

        assert!(waiter.await.unwrap());
    }

    #[tokio::test]
    async fn notified_wakes_on_push_back() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let _h = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();

        let pool2 = pool.clone();
        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        tokio::time::sleep(Duration::from_millis(20)).await;
        pool.push_back(task);

        assert!(waiter.await.unwrap());
    }

    // ── edge cases ──────────────────────────────────────────────────────

    #[test]
    fn pool_with_zero_capacity_rejects_everything() {
        let pool = TaskPool::new(0);
        match pool.try_submit(test_request()) {
            Err(ScatterProxyError::PoolFull { capacity }) => assert_eq!(capacity, 0),
            Err(other) => panic!("expected PoolFull, got {other:?}"),
            Ok(_) => panic!("zero-capacity pool accepted a task"),
        }
        assert!(pool.try_submit_batch(vec![test_request()]).is_err());
        assert_eq!(pool.pending_count(), 0);
    }

    #[test]
    fn pool_allows_try_submit_after_mark_completed_frees_space() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        // Pool is full.
        assert!(pool.try_submit(test_request()).is_err());

        // Simulate task completion.
        let skip = HashSet::new();
        let _task = pool.pick_next(&skip).unwrap();
        pool.mark_completed();

        // Now there's room again.
        let _h2 = pool.try_submit(test_request()).unwrap();
    }

    #[test]
    fn task_entry_has_correct_defaults_on_try_submit() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.attempts, 0);
        assert!(task.last_error.is_empty());
        assert!(task.result_tx.is_some());
    }

    #[test]
    fn scatter_response_debug() {
        let resp = ScatterResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: Bytes::from_static(b"test"),
        };
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("200"));
    }
}