Skip to main content

scatter_proxy/
task.rs

1use std::collections::HashSet;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Mutex;
7use std::task::{Context, Poll};
8use std::time::{Duration, Instant};
9
10use bytes::Bytes;
11use http::{HeaderMap, StatusCode};
12use tokio::sync::{oneshot, Notify};
13
14use crate::error::ScatterProxyError;
15
/// Response from a successful proxied request.
///
/// The body is held fully in memory as a single [`Bytes`] buffer.
#[derive(Debug)]
pub struct ScatterResponse {
    /// HTTP status code of the proxied response.
    pub status: StatusCode,
    /// Headers of the proxied response.
    pub headers: HeaderMap,
    /// Complete response body.
    pub body: Bytes,
}
23
24/// Handle returned to the caller when a task is submitted.
25/// Implements `Future` so the caller can `.await` the proxied result.
26#[derive(Debug)]
27pub struct TaskHandle {
28    rx: oneshot::Receiver<Result<ScatterResponse, ScatterProxyError>>,
29}
30
31impl Future for TaskHandle {
32    type Output = Result<ScatterResponse, ScatterProxyError>;
33
34    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
35        Pin::new(&mut self.rx).poll(cx).map(|result| {
36            result.unwrap_or_else(|_| Err(ScatterProxyError::Init("task channel closed".into())))
37        })
38    }
39}
40
/// Internal task entry stored in the pool queue.
///
/// Created by `TaskPool::submit` / `TaskPool::submit_batch`, removed by
/// `TaskPool::pick_next`, and may be re-queued with `TaskPool::push_back`.
pub(crate) struct TaskEntry {
    /// Sequential ID drawn from the pool's `next_id` counter at submission.
    // Within this file `id` is only read by the unit tests, hence the lint
    // allowance. NOTE(review): confirm no other module reads it.
    #[allow(dead_code)]
    pub id: u64,
    /// The request to be proxied.
    pub request: reqwest::Request,
    /// Host of the request URL, or `"unknown"` when the URL has none; matched
    /// against `skip_hosts` by `TaskPool::pick_next`.
    pub host: String,
    /// Attempts made so far; `0` when first submitted.
    // NOTE(review): incremented outside this file, presumably by the scheduler.
    pub attempts: usize,
    /// Attempt budget supplied by the submitter.
    pub max_attempts: usize,
    /// When the pool created this entry.
    pub submitted_at: Instant,
    /// Timeout supplied by the submitter.
    // NOTE(review): whether this bounds each attempt or the task as a whole is
    // decided by the consumer of `TaskEntry`, which is not shown here.
    pub task_timeout: Duration,
    /// Sender for the caller's `TaskHandle`. Wrapped in `Option`, presumably so
    /// it can be taken out when the final result is sent (a oneshot sender is
    /// consumed by `send`).
    pub result_tx: Option<oneshot::Sender<Result<ScatterResponse, ScatterProxyError>>>,
    /// Description of the most recent failure; empty until one is recorded.
    pub last_error: String,
}
54
55/// Thread-safe task pool with FIFO ordering and bounded capacity.
56pub struct TaskPool {
57    queue: Mutex<VecDeque<TaskEntry>>,
58    capacity: usize,
59    next_id: AtomicU64,
60    notify: Notify,
61    completed: AtomicU64,
62    failed: AtomicU64,
63}
64
65impl TaskPool {
66    /// Create a new task pool with the given maximum capacity.
67    pub fn new(capacity: usize) -> Self {
68        Self {
69            queue: Mutex::new(VecDeque::new()),
70            capacity,
71            next_id: AtomicU64::new(1),
72            notify: Notify::new(),
73            completed: AtomicU64::new(0),
74            failed: AtomicU64::new(0),
75        }
76    }
77
78    /// Submit a single task. Returns a `TaskHandle` for await-ing the result.
79    ///
80    /// Returns `Err(PoolFull)` when the queue is already at capacity.
81    pub fn submit(
82        &self,
83        request: reqwest::Request,
84        max_attempts: usize,
85        task_timeout: Duration,
86    ) -> Result<TaskHandle, ScatterProxyError> {
87        let host = request.url().host_str().unwrap_or("unknown").to_string();
88
89        let (tx, rx) = oneshot::channel();
90        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
91
92        let entry = TaskEntry {
93            id,
94            request,
95            host,
96            attempts: 0,
97            max_attempts,
98            submitted_at: Instant::now(),
99            task_timeout,
100            result_tx: Some(tx),
101            last_error: String::new(),
102        };
103
104        {
105            let mut queue = self.queue.lock().unwrap();
106            if queue.len() >= self.capacity {
107                return Err(ScatterProxyError::PoolFull {
108                    capacity: self.capacity,
109                });
110            }
111            queue.push_back(entry);
112        }
113
114        self.notify.notify_one();
115
116        Ok(TaskHandle { rx })
117    }
118
119    /// Submit multiple tasks at once. Returns a `TaskHandle` per request.
120    ///
121    /// If the pool doesn't have room for the entire batch, no tasks are added
122    /// and `Err(PoolFull)` is returned.
123    pub fn submit_batch(
124        &self,
125        requests: Vec<reqwest::Request>,
126        max_attempts: usize,
127        task_timeout: Duration,
128    ) -> Result<Vec<TaskHandle>, ScatterProxyError> {
129        let count = requests.len();
130
131        // Pre-check capacity so the batch is atomic.
132        {
133            let queue = self.queue.lock().unwrap();
134            if queue.len() + count > self.capacity {
135                return Err(ScatterProxyError::PoolFull {
136                    capacity: self.capacity,
137                });
138            }
139        }
140
141        let mut handles = Vec::with_capacity(count);
142
143        {
144            let mut queue = self.queue.lock().unwrap();
145
146            // Double-check after re-acquiring the lock.
147            if queue.len() + count > self.capacity {
148                return Err(ScatterProxyError::PoolFull {
149                    capacity: self.capacity,
150                });
151            }
152
153            for request in requests {
154                let host = request.url().host_str().unwrap_or("unknown").to_string();
155                let (tx, rx) = oneshot::channel();
156                let id = self.next_id.fetch_add(1, Ordering::Relaxed);
157
158                let entry = TaskEntry {
159                    id,
160                    request,
161                    host,
162                    attempts: 0,
163                    max_attempts,
164                    submitted_at: Instant::now(),
165                    task_timeout,
166                    result_tx: Some(tx),
167                    last_error: String::new(),
168                };
169                queue.push_back(entry);
170                handles.push(TaskHandle { rx });
171            }
172        }
173
174        // Wake the scheduler — one notification per task added.
175        for _ in 0..count {
176            self.notify.notify_one();
177        }
178
179        Ok(handles)
180    }
181
182    /// Pick the next eligible task from the front of the queue.
183    ///
184    /// Tasks whose host appears in `skip_hosts` (e.g. circuit-broken hosts) are
185    /// left in the queue. Returns `None` when no eligible task is found.
186    pub(crate) fn pick_next(&self, skip_hosts: &HashSet<String>) -> Option<TaskEntry> {
187        let mut queue = self.queue.lock().unwrap();
188        let len = queue.len();
189
190        for i in 0..len {
191            if let Some(entry) = queue.get(i) {
192                if !skip_hosts.contains(&entry.host) {
193                    return queue.remove(i);
194                }
195            }
196        }
197
198        None
199    }
200
201    /// Push a failed task back to the tail of the queue for retry.
202    pub(crate) fn push_back(&self, entry: TaskEntry) {
203        {
204            let mut queue = self.queue.lock().unwrap();
205            queue.push_back(entry);
206        }
207        self.notify.notify_one();
208    }
209
210    /// Number of tasks currently waiting in the queue.
211    pub fn pending_count(&self) -> usize {
212        let queue = self.queue.lock().unwrap();
213        queue.len()
214    }
215
216    /// Total number of tasks that completed successfully.
217    pub fn completed_count(&self) -> u64 {
218        self.completed.load(Ordering::Relaxed)
219    }
220
221    /// Total number of tasks that failed permanently.
222    pub fn failed_count(&self) -> u64 {
223        self.failed.load(Ordering::Relaxed)
224    }
225
226    /// Increment the completed counter.
227    pub(crate) fn mark_completed(&self) {
228        self.completed.fetch_add(1, Ordering::Relaxed);
229    }
230
231    /// Increment the failed counter.
232    pub(crate) fn mark_failed(&self) {
233        self.failed.fetch_add(1, Ordering::Relaxed);
234    }
235
236    /// Wait until a task becomes available (a task is submitted or pushed back).
237    #[allow(dead_code)]
238    pub(crate) async fn notified(&self) {
239        self.notify.notified().await;
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    /// Helper: build a GET request to the given URL.
    ///
    /// Constructs the `reqwest::Request` directly. Going through
    /// `reqwest::Client::new()` would initialize a TLS backend (and panic if it
    /// cannot) for every request the tests create.
    fn test_request(url: &str) -> reqwest::Request {
        let url: reqwest::Url = url.parse().expect("test URL must be valid");
        reqwest::Request::new(reqwest::Method::GET, url)
    }

    // -----------------------------------------------------------------------
    // TaskPool::new
    // -----------------------------------------------------------------------

    #[test]
    fn new_pool_has_zero_pending() {
        let pool = TaskPool::new(100);
        assert_eq!(pool.pending_count(), 0);
        assert_eq!(pool.completed_count(), 0);
        assert_eq!(pool.failed_count(), 0);
    }

    // -----------------------------------------------------------------------
    // submit
    // -----------------------------------------------------------------------

    #[test]
    fn submit_increments_pending_count() {
        let pool = TaskPool::new(10);
        let _h = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn submit_returns_pool_full_when_at_capacity() {
        let pool = TaskPool::new(1);
        let _h1 = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let result = pool.submit(
            test_request("http://example.com/2"),
            3,
            Duration::from_secs(10),
        );
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::PoolFull { capacity } => assert_eq!(capacity, 1),
            other => panic!("expected PoolFull, got: {other:?}"),
        }
    }

    #[test]
    fn submit_assigns_incrementing_ids() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(test_request("http://a.com"), 3, Duration::from_secs(10))
            .unwrap();
        let _h2 = pool
            .submit(test_request("http://b.com"), 3, Duration::from_secs(10))
            .unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t2.id > t1.id);
    }

    #[test]
    fn submit_extracts_host_from_url() {
        let pool = TaskPool::new(10);
        let _h = pool
            .submit(
                test_request("http://myhost.example.com/path?q=1"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();
        assert_eq!(entry.host, "myhost.example.com");
    }

    #[test]
    fn submit_uses_unknown_host_when_url_has_none() {
        let pool = TaskPool::new(10);
        // `data:` URLs have no host component.
        let _h = pool
            .submit(
                test_request("data:text/plain,hello"),
                1,
                Duration::from_secs(5),
            )
            .unwrap();

        let entry = pool.pick_next(&HashSet::new()).unwrap();
        assert_eq!(entry.host, "unknown");
    }

    // -----------------------------------------------------------------------
    // submit_batch
    // -----------------------------------------------------------------------

    #[test]
    fn submit_batch_adds_all_tasks() {
        let pool = TaskPool::new(10);
        let reqs = vec![
            test_request("http://a.com"),
            test_request("http://b.com"),
            test_request("http://c.com"),
        ];
        let handles = pool.submit_batch(reqs, 3, Duration::from_secs(10)).unwrap();
        assert_eq!(handles.len(), 3);
        assert_eq!(pool.pending_count(), 3);
    }

    #[test]
    fn submit_batch_atomic_rejection_when_pool_full() {
        let pool = TaskPool::new(2);
        let _h = pool
            .submit(test_request("http://x.com"), 3, Duration::from_secs(10))
            .unwrap();

        // 1 already in pool, capacity 2, trying to add 2 more → should fail
        let reqs = vec![test_request("http://a.com"), test_request("http://b.com")];
        let result = pool.submit_batch(reqs, 3, Duration::from_secs(10));
        assert!(result.is_err());
        // Original task should still be there
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn submit_batch_accepts_batch_that_exactly_fills_pool() {
        let pool = TaskPool::new(3);
        let _h = pool
            .submit(test_request("http://x.com"), 3, Duration::from_secs(10))
            .unwrap();

        // 1 already in pool, capacity 3, adding 2 more fills it exactly.
        let reqs = vec![test_request("http://a.com"), test_request("http://b.com")];
        let handles = pool.submit_batch(reqs, 3, Duration::from_secs(10)).unwrap();
        assert_eq!(handles.len(), 2);
        assert_eq!(pool.pending_count(), 3);
    }

    #[test]
    fn submit_batch_empty_vec_is_ok() {
        let pool = TaskPool::new(10);
        let handles = pool
            .submit_batch(vec![], 3, Duration::from_secs(10))
            .unwrap();
        assert!(handles.is_empty());
        assert_eq!(pool.pending_count(), 0);
    }

    #[test]
    fn submit_batch_preserves_order_and_assigns_increasing_ids() {
        let pool = TaskPool::new(10);
        let reqs = vec![
            test_request("http://a.com"),
            test_request("http://b.com"),
            test_request("http://c.com"),
        ];
        let _handles = pool.submit_batch(reqs, 3, Duration::from_secs(10)).unwrap();

        let skip = HashSet::new();
        let picked: Vec<TaskEntry> = std::iter::from_fn(|| pool.pick_next(&skip)).collect();
        let hosts: Vec<&str> = picked.iter().map(|entry| entry.host.as_str()).collect();
        assert_eq!(hosts, ["a.com", "b.com", "c.com"]);
        assert!(picked.windows(2).all(|pair| pair[0].id < pair[1].id));
    }

    #[tokio::test]
    async fn submit_batch_handles_resolve_to_their_own_task() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request("http://a.com"), test_request("http://b.com")];
        let mut handles = pool.submit_batch(reqs, 3, Duration::from_secs(10)).unwrap();

        let skip = HashSet::new();
        let first = pool.pick_next(&skip).unwrap();
        let second = pool.pick_next(&skip).unwrap();

        // Answer in reverse order to show each handle is tied to its own entry.
        for (entry, status) in [(second, StatusCode::CREATED), (first, StatusCode::OK)] {
            let response = ScatterResponse {
                status,
                headers: HeaderMap::new(),
                body: Bytes::new(),
            };
            entry.result_tx.unwrap().send(Ok(response)).unwrap();
        }

        let second_handle = handles.pop().unwrap();
        let first_handle = handles.pop().unwrap();
        assert_eq!(first_handle.await.unwrap().status, StatusCode::OK);
        assert_eq!(second_handle.await.unwrap().status, StatusCode::CREATED);
    }

    // -----------------------------------------------------------------------
    // pick_next
    // -----------------------------------------------------------------------

    #[test]
    fn pick_next_returns_fifo_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(test_request("http://first.com"), 3, Duration::from_secs(10))
            .unwrap();
        let _h2 = pool
            .submit(
                test_request("http://second.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let t = pool.pick_next(&skip).unwrap();
        assert_eq!(t.host, "first.com");
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn pick_next_skips_circuit_broken_hosts() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(
                test_request("http://broken.com/a"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();
        let _h2 = pool
            .submit(test_request("http://ok.com/b"), 3, Duration::from_secs(10))
            .unwrap();

        let mut skip = HashSet::new();
        skip.insert("broken.com".to_string());

        let t = pool.pick_next(&skip).unwrap();
        assert_eq!(t.host, "ok.com");
        // broken.com task is still in the queue
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn pick_next_returns_none_when_all_hosts_skipped() {
        let pool = TaskPool::new(10);
        let _h = pool
            .submit(
                test_request("http://broken.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let mut skip = HashSet::new();
        skip.insert("broken.com".to_string());

        assert!(pool.pick_next(&skip).is_none());
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn pick_next_returns_none_when_empty() {
        let pool = TaskPool::new(10);
        let skip = HashSet::new();
        assert!(pool.pick_next(&skip).is_none());
    }

    // -----------------------------------------------------------------------
    // push_back
    // -----------------------------------------------------------------------

    #[test]
    fn push_back_requeues_to_tail() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(test_request("http://first.com"), 3, Duration::from_secs(10))
            .unwrap();
        let _h2 = pool
            .submit(
                test_request("http://second.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let mut entry = pool.pick_next(&skip).unwrap();
        assert_eq!(entry.host, "first.com");
        entry.attempts += 1;
        entry.last_error = "connection refused".into();

        pool.push_back(entry);
        assert_eq!(pool.pending_count(), 2);

        // second.com should come first now
        let t = pool.pick_next(&skip).unwrap();
        assert_eq!(t.host, "second.com");

        // then re-queued first.com
        let t = pool.pick_next(&skip).unwrap();
        assert_eq!(t.host, "first.com");
        assert_eq!(t.attempts, 1);
        assert_eq!(t.last_error, "connection refused");
    }

    #[test]
    fn push_back_does_not_enforce_capacity() {
        let pool = TaskPool::new(1);
        let _h1 = pool
            .submit(test_request("http://a.com"), 3, Duration::from_secs(5))
            .unwrap();

        // Take the task out (freeing the only slot), let another task fill it,
        // then retry the first one: the retry must still be accepted.
        let entry = pool.pick_next(&HashSet::new()).unwrap();
        let _h2 = pool
            .submit(test_request("http://b.com"), 3, Duration::from_secs(5))
            .unwrap();
        pool.push_back(entry);

        assert_eq!(pool.pending_count(), 2);
    }

    // -----------------------------------------------------------------------
    // mark_completed / mark_failed
    // -----------------------------------------------------------------------

    #[test]
    fn mark_completed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_completed();
        pool.mark_completed();
        assert_eq!(pool.completed_count(), 2);
    }

    #[test]
    fn mark_failed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_failed();
        assert_eq!(pool.failed_count(), 1);
    }

    // -----------------------------------------------------------------------
    // TaskHandle as Future
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn task_handle_receives_success() {
        let pool = TaskPool::new(10);
        let handle = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();

        let response = ScatterResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: Bytes::from("hello"),
        };

        entry.result_tx.unwrap().send(Ok(response)).unwrap();

        let result = handle.await;
        assert!(result.is_ok());
        let resp = result.unwrap();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body, Bytes::from("hello"));
    }

    #[tokio::test]
    async fn task_handle_receives_error() {
        let pool = TaskPool::new(10);
        let handle = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();

        entry
            .result_tx
            .unwrap()
            .send(Err(ScatterProxyError::MaxAttemptsExhausted {
                host: "example.com".into(),
                attempts: 3,
                last_error: "timeout".into(),
            }))
            .unwrap();

        let result = handle.await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn task_handle_returns_error_when_sender_dropped() {
        let pool = TaskPool::new(10);
        let handle = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        // Pick and drop the entry (which drops the sender).
        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();
        drop(entry);

        let result = handle.await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::Init(msg) => assert!(msg.contains("channel closed")),
            other => panic!("expected Init, got: {other:?}"),
        }
    }

    // -----------------------------------------------------------------------
    // notified
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn notified_wakes_on_submit() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let pool2 = pool.clone();

        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        // Give the waiter a moment to park.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let _h = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let woke = tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .unwrap()
            .unwrap();
        assert!(woke);
    }

    #[tokio::test]
    async fn notified_wakes_on_push_back() {
        let pool = std::sync::Arc::new(TaskPool::new(10));

        // Submit and pick a task to get an entry we can push_back.
        let _h = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();
        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();

        let pool2 = pool.clone();
        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        tokio::time::sleep(Duration::from_millis(10)).await;

        pool.push_back(entry);

        let woke = tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .unwrap()
            .unwrap();
        assert!(woke);
    }

    // -----------------------------------------------------------------------
    // Capacity edge cases
    // -----------------------------------------------------------------------

    #[test]
    fn pool_with_zero_capacity_rejects_everything() {
        let pool = TaskPool::new(0);
        let result = pool.submit(test_request("http://a.com"), 1, Duration::from_secs(5));
        assert!(matches!(
            result,
            Err(ScatterProxyError::PoolFull { capacity: 0 })
        ));
    }

    #[test]
    fn pool_allows_submit_after_pick_frees_space() {
        let pool = TaskPool::new(1);
        let _h1 = pool
            .submit(test_request("http://a.com"), 1, Duration::from_secs(5))
            .unwrap();

        // Pool full now.
        assert!(pool
            .submit(test_request("http://b.com"), 1, Duration::from_secs(5))
            .is_err());

        // Pick the task, freeing a slot.
        let skip = HashSet::new();
        let _entry = pool.pick_next(&skip).unwrap();

        // Now we can submit again.
        let _h2 = pool
            .submit(test_request("http://c.com"), 1, Duration::from_secs(5))
            .unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    // -----------------------------------------------------------------------
    // TaskEntry fields
    // -----------------------------------------------------------------------

    #[test]
    fn task_entry_has_correct_defaults_on_submit() {
        let pool = TaskPool::new(10);
        let timeout = Duration::from_secs(42);
        let _h = pool
            .submit(test_request("http://host.example.com/path"), 7, timeout)
            .unwrap();

        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();

        assert_eq!(entry.host, "host.example.com");
        assert_eq!(entry.attempts, 0);
        assert_eq!(entry.max_attempts, 7);
        assert_eq!(entry.task_timeout, timeout);
        assert!(entry.last_error.is_empty());
        assert!(entry.result_tx.is_some());
    }

    // -----------------------------------------------------------------------
    // ScatterResponse basics
    // -----------------------------------------------------------------------

    #[test]
    fn scatter_response_debug() {
        let resp = ScatterResponse {
            status: StatusCode::NOT_FOUND,
            headers: HeaderMap::new(),
            body: Bytes::from("not found"),
        };
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("404"));
    }

    // -----------------------------------------------------------------------
    // Multiple hosts with partial skipping
    // -----------------------------------------------------------------------

    #[test]
    fn pick_next_selects_first_non_skipped_preserves_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(test_request("http://a.com/1"), 1, Duration::from_secs(5))
            .unwrap();
        let _h2 = pool
            .submit(test_request("http://b.com/2"), 1, Duration::from_secs(5))
            .unwrap();
        let _h3 = pool
            .submit(test_request("http://a.com/3"), 1, Duration::from_secs(5))
            .unwrap();
        let _h4 = pool
            .submit(test_request("http://c.com/4"), 1, Duration::from_secs(5))
            .unwrap();

        let mut skip = HashSet::new();
        skip.insert("a.com".to_string());

        // Should skip both a.com entries, return b.com first
        let t1 = pool.pick_next(&skip).unwrap();
        assert_eq!(t1.host, "b.com");

        let t2 = pool.pick_next(&skip).unwrap();
        assert_eq!(t2.host, "c.com");

        // Only a.com entries remain
        assert_eq!(pool.pending_count(), 2);
        assert!(pool.pick_next(&skip).is_none());
    }
}