1use std::collections::HashSet;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Mutex;
7use std::task::{Context, Poll};
8use std::time::{Duration, Instant};
9
10use bytes::Bytes;
11use http::{HeaderMap, StatusCode};
12use tokio::sync::{oneshot, Notify};
13
14use crate::error::ScatterProxyError;
15
16#[derive(Debug)]
18pub struct ScatterResponse {
19 pub status: StatusCode,
20 pub headers: HeaderMap,
21 pub body: Bytes,
22}
23
24#[derive(Debug)]
27pub struct TaskHandle {
28 rx: oneshot::Receiver<Result<ScatterResponse, ScatterProxyError>>,
29}
30
31impl Future for TaskHandle {
32 type Output = Result<ScatterResponse, ScatterProxyError>;
33
34 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
35 Pin::new(&mut self.rx).poll(cx).map(|result| {
36 result.unwrap_or_else(|_| Err(ScatterProxyError::Init("task channel closed".into())))
37 })
38 }
39}
40
41pub(crate) struct TaskEntry {
43 #[allow(dead_code)]
44 pub id: u64,
45 pub request: reqwest::Request,
46 pub host: String,
47 pub attempts: usize,
48 pub max_attempts: usize,
49 pub submitted_at: Instant,
50 pub task_timeout: Duration,
51 pub result_tx: Option<oneshot::Sender<Result<ScatterResponse, ScatterProxyError>>>,
52 pub last_error: String,
53}
54
55pub struct TaskPool {
57 queue: Mutex<VecDeque<TaskEntry>>,
58 capacity: usize,
59 next_id: AtomicU64,
60 notify: Notify,
61 completed: AtomicU64,
62 failed: AtomicU64,
63}
64
65impl TaskPool {
66 pub fn new(capacity: usize) -> Self {
68 Self {
69 queue: Mutex::new(VecDeque::new()),
70 capacity,
71 next_id: AtomicU64::new(1),
72 notify: Notify::new(),
73 completed: AtomicU64::new(0),
74 failed: AtomicU64::new(0),
75 }
76 }
77
78 pub fn submit(
82 &self,
83 request: reqwest::Request,
84 max_attempts: usize,
85 task_timeout: Duration,
86 ) -> Result<TaskHandle, ScatterProxyError> {
87 let host = request.url().host_str().unwrap_or("unknown").to_string();
88
89 let (tx, rx) = oneshot::channel();
90 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
91
92 let entry = TaskEntry {
93 id,
94 request,
95 host,
96 attempts: 0,
97 max_attempts,
98 submitted_at: Instant::now(),
99 task_timeout,
100 result_tx: Some(tx),
101 last_error: String::new(),
102 };
103
104 {
105 let mut queue = self.queue.lock().unwrap();
106 if queue.len() >= self.capacity {
107 return Err(ScatterProxyError::PoolFull {
108 capacity: self.capacity,
109 });
110 }
111 queue.push_back(entry);
112 }
113
114 self.notify.notify_one();
115
116 Ok(TaskHandle { rx })
117 }
118
119 pub fn submit_batch(
124 &self,
125 requests: Vec<reqwest::Request>,
126 max_attempts: usize,
127 task_timeout: Duration,
128 ) -> Result<Vec<TaskHandle>, ScatterProxyError> {
129 let count = requests.len();
130
131 {
133 let queue = self.queue.lock().unwrap();
134 if queue.len() + count > self.capacity {
135 return Err(ScatterProxyError::PoolFull {
136 capacity: self.capacity,
137 });
138 }
139 }
140
141 let mut handles = Vec::with_capacity(count);
142
143 {
144 let mut queue = self.queue.lock().unwrap();
145
146 if queue.len() + count > self.capacity {
148 return Err(ScatterProxyError::PoolFull {
149 capacity: self.capacity,
150 });
151 }
152
153 for request in requests {
154 let host = request.url().host_str().unwrap_or("unknown").to_string();
155 let (tx, rx) = oneshot::channel();
156 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
157
158 let entry = TaskEntry {
159 id,
160 request,
161 host,
162 attempts: 0,
163 max_attempts,
164 submitted_at: Instant::now(),
165 task_timeout,
166 result_tx: Some(tx),
167 last_error: String::new(),
168 };
169 queue.push_back(entry);
170 handles.push(TaskHandle { rx });
171 }
172 }
173
174 for _ in 0..count {
176 self.notify.notify_one();
177 }
178
179 Ok(handles)
180 }
181
182 pub(crate) fn pick_next(&self, skip_hosts: &HashSet<String>) -> Option<TaskEntry> {
187 let mut queue = self.queue.lock().unwrap();
188 let len = queue.len();
189
190 for i in 0..len {
191 if let Some(entry) = queue.get(i) {
192 if !skip_hosts.contains(&entry.host) {
193 return queue.remove(i);
194 }
195 }
196 }
197
198 None
199 }
200
201 pub(crate) fn push_back(&self, entry: TaskEntry) {
203 {
204 let mut queue = self.queue.lock().unwrap();
205 queue.push_back(entry);
206 }
207 self.notify.notify_one();
208 }
209
210 pub fn pending_count(&self) -> usize {
212 let queue = self.queue.lock().unwrap();
213 queue.len()
214 }
215
216 pub fn completed_count(&self) -> u64 {
218 self.completed.load(Ordering::Relaxed)
219 }
220
221 pub fn failed_count(&self) -> u64 {
223 self.failed.load(Ordering::Relaxed)
224 }
225
226 pub(crate) fn mark_completed(&self) {
228 self.completed.fetch_add(1, Ordering::Relaxed);
229 }
230
231 pub(crate) fn mark_failed(&self) {
233 self.failed.fetch_add(1, Ordering::Relaxed);
234 }
235
236 #[allow(dead_code)]
238 pub(crate) async fn notified(&self) {
239 self.notify.notified().await;
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    /// Builds a bodiless GET request for `url`.
    fn test_request(url: &str) -> reqwest::Request {
        reqwest::Client::new().get(url).build().unwrap()
    }

    #[test]
    fn new_pool_has_zero_pending() {
        let pool = TaskPool::new(100);
        assert_eq!(pool.pending_count(), 0);
        assert_eq!(pool.completed_count(), 0);
        assert_eq!(pool.failed_count(), 0);
    }

    #[test]
    fn submit_increments_pending_count() {
        let pool = TaskPool::new(10);
        let _h = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn submit_returns_pool_full_when_at_capacity() {
        let pool = TaskPool::new(1);
        let _h1 = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let result = pool.submit(
            test_request("http://example.com/2"),
            3,
            Duration::from_secs(10),
        );
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::PoolFull { capacity } => assert_eq!(capacity, 1),
            other => panic!("expected PoolFull, got: {other:?}"),
        }
    }

    #[test]
    fn submit_assigns_incrementing_ids() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(test_request("http://a.com"), 3, Duration::from_secs(10))
            .unwrap();
        let _h2 = pool
            .submit(test_request("http://b.com"), 3, Duration::from_secs(10))
            .unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t2.id > t1.id);
    }

    #[test]
    fn submit_extracts_host_from_url() {
        let pool = TaskPool::new(10);
        let _h = pool
            .submit(
                test_request("http://myhost.example.com/path?q=1"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();
        assert_eq!(entry.host, "myhost.example.com");
    }

    #[test]
    fn submit_batch_adds_all_tasks() {
        let pool = TaskPool::new(10);
        let reqs = vec![
            test_request("http://a.com"),
            test_request("http://b.com"),
            test_request("http://c.com"),
        ];
        let handles = pool.submit_batch(reqs, 3, Duration::from_secs(10)).unwrap();
        assert_eq!(handles.len(), 3);
        assert_eq!(pool.pending_count(), 3);
    }

    // Batch members are queued in input order with increasing ids.
    #[test]
    fn submit_batch_preserves_order_and_ids() {
        let pool = TaskPool::new(10);
        let reqs = vec![
            test_request("http://a.com"),
            test_request("http://b.com"),
            test_request("http://c.com"),
        ];
        let _handles = pool.submit_batch(reqs, 3, Duration::from_secs(10)).unwrap();

        let skip = HashSet::new();
        let picked: Vec<TaskEntry> = std::iter::from_fn(|| pool.pick_next(&skip)).collect();
        let hosts: Vec<&str> = picked.iter().map(|e| e.host.as_str()).collect();
        assert_eq!(hosts, ["a.com", "b.com", "c.com"]);
        assert!(picked.windows(2).all(|pair| pair[0].id < pair[1].id));
    }

    #[test]
    fn submit_batch_atomic_rejection_when_pool_full() {
        let pool = TaskPool::new(2);
        let _h = pool
            .submit(test_request("http://x.com"), 3, Duration::from_secs(10))
            .unwrap();

        // Only one slot is free, so the two-request batch must be rejected as a whole.
        let reqs = vec![test_request("http://a.com"), test_request("http://b.com")];
        let result = pool.submit_batch(reqs, 3, Duration::from_secs(10));
        assert!(result.is_err());
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn submit_batch_empty_vec_is_ok() {
        let pool = TaskPool::new(10);
        let handles = pool
            .submit_batch(vec![], 3, Duration::from_secs(10))
            .unwrap();
        assert!(handles.is_empty());
        assert_eq!(pool.pending_count(), 0);
    }

    #[test]
    fn pick_next_returns_fifo_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(test_request("http://first.com"), 3, Duration::from_secs(10))
            .unwrap();
        let _h2 = pool
            .submit(
                test_request("http://second.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let t = pool.pick_next(&skip).unwrap();
        assert_eq!(t.host, "first.com");
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn pick_next_skips_circuit_broken_hosts() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(
                test_request("http://broken.com/a"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();
        let _h2 = pool
            .submit(test_request("http://ok.com/b"), 3, Duration::from_secs(10))
            .unwrap();

        let mut skip = HashSet::new();
        skip.insert("broken.com".to_string());

        let t = pool.pick_next(&skip).unwrap();
        assert_eq!(t.host, "ok.com");
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn pick_next_returns_none_when_all_hosts_skipped() {
        let pool = TaskPool::new(10);
        let _h = pool
            .submit(
                test_request("http://broken.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let mut skip = HashSet::new();
        skip.insert("broken.com".to_string());

        assert!(pool.pick_next(&skip).is_none());
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn pick_next_returns_none_when_empty() {
        let pool = TaskPool::new(10);
        let skip = HashSet::new();
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn push_back_requeues_to_tail() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(test_request("http://first.com"), 3, Duration::from_secs(10))
            .unwrap();
        let _h2 = pool
            .submit(
                test_request("http://second.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let mut entry = pool.pick_next(&skip).unwrap();
        assert_eq!(entry.host, "first.com");
        entry.attempts += 1;
        entry.last_error = "connection refused".into();

        pool.push_back(entry);
        assert_eq!(pool.pending_count(), 2);

        // The requeued task now sits behind the one that was already waiting.
        let t = pool.pick_next(&skip).unwrap();
        assert_eq!(t.host, "second.com");

        let t = pool.pick_next(&skip).unwrap();
        assert_eq!(t.host, "first.com");
        assert_eq!(t.attempts, 1);
        assert_eq!(t.last_error, "connection refused");
    }

    // Requeueing an already-admitted task must not be refused by the capacity limit.
    #[test]
    fn push_back_is_not_limited_by_capacity() {
        let pool = TaskPool::new(1);
        let _h1 = pool
            .submit(test_request("http://a.com"), 1, Duration::from_secs(5))
            .unwrap();
        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();
        let _h2 = pool
            .submit(test_request("http://b.com"), 1, Duration::from_secs(5))
            .unwrap();

        pool.push_back(entry);
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn mark_completed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_completed();
        pool.mark_completed();
        assert_eq!(pool.completed_count(), 2);
    }

    #[test]
    fn mark_failed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_failed();
        assert_eq!(pool.failed_count(), 1);
    }

    #[tokio::test]
    async fn task_handle_receives_success() {
        let pool = TaskPool::new(10);
        let handle = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();

        let response = ScatterResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: Bytes::from("hello"),
        };

        entry.result_tx.unwrap().send(Ok(response)).unwrap();

        let result = handle.await;
        assert!(result.is_ok());
        let resp = result.unwrap();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body, Bytes::from("hello"));
    }

    #[tokio::test]
    async fn task_handle_receives_error() {
        let pool = TaskPool::new(10);
        let handle = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();

        entry
            .result_tx
            .unwrap()
            .send(Err(ScatterProxyError::MaxAttemptsExhausted {
                host: "example.com".into(),
                attempts: 3,
                last_error: "timeout".into(),
            }))
            .unwrap();

        let result = handle.await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn task_handle_returns_error_when_sender_dropped() {
        let pool = TaskPool::new(10);
        let handle = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        // Dropping the entry drops its result sender without sending anything.
        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();
        drop(entry);

        let result = handle.await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::Init(msg) => assert!(msg.contains("channel closed")),
            other => panic!("expected Init, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn notified_wakes_on_submit() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let pool2 = pool.clone();

        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        // Give the waiter a chance to start waiting before the submit.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let _h = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();

        let woke = tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .unwrap()
            .unwrap();
        assert!(woke);
    }

    #[tokio::test]
    async fn notified_wakes_on_push_back() {
        let pool = std::sync::Arc::new(TaskPool::new(10));

        let _h = pool
            .submit(
                test_request("http://example.com"),
                3,
                Duration::from_secs(10),
            )
            .unwrap();
        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();

        // `submit` notified while nobody was waiting, which leaves a stored permit.
        // Consume it here; otherwise the waiter below would wake on that stale permit
        // and this test would pass even if `push_back` never notified.
        tokio::time::timeout(Duration::from_secs(1), pool.notified())
            .await
            .expect("submit should have left a pending notification");

        let pool2 = pool.clone();
        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        tokio::time::sleep(Duration::from_millis(10)).await;

        pool.push_back(entry);

        let woke = tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .unwrap()
            .unwrap();
        assert!(woke);
    }

    #[test]
    fn pool_with_zero_capacity_rejects_everything() {
        let pool = TaskPool::new(0);
        let result = pool.submit(test_request("http://a.com"), 1, Duration::from_secs(5));
        assert!(matches!(
            result,
            Err(ScatterProxyError::PoolFull { capacity: 0 })
        ));
    }

    #[test]
    fn pool_allows_submit_after_pick_frees_space() {
        let pool = TaskPool::new(1);
        let _h1 = pool
            .submit(test_request("http://a.com"), 1, Duration::from_secs(5))
            .unwrap();

        // The pool is full.
        assert!(pool
            .submit(test_request("http://b.com"), 1, Duration::from_secs(5))
            .is_err());

        // Picking a task frees its slot.
        let skip = HashSet::new();
        let _entry = pool.pick_next(&skip).unwrap();

        let _h2 = pool
            .submit(test_request("http://c.com"), 1, Duration::from_secs(5))
            .unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn task_entry_has_correct_defaults_on_submit() {
        let pool = TaskPool::new(10);
        let timeout = Duration::from_secs(42);
        let _h = pool
            .submit(test_request("http://host.example.com/path"), 7, timeout)
            .unwrap();

        let skip = HashSet::new();
        let entry = pool.pick_next(&skip).unwrap();

        assert_eq!(entry.host, "host.example.com");
        assert_eq!(entry.attempts, 0);
        assert_eq!(entry.max_attempts, 7);
        assert_eq!(entry.task_timeout, timeout);
        assert!(entry.last_error.is_empty());
        assert!(entry.result_tx.is_some());
    }

    #[test]
    fn scatter_response_debug() {
        let resp = ScatterResponse {
            status: StatusCode::NOT_FOUND,
            headers: HeaderMap::new(),
            body: Bytes::from("not found"),
        };
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("404"));
    }

    #[test]
    fn pick_next_selects_first_non_skipped_preserves_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool
            .submit(test_request("http://a.com/1"), 1, Duration::from_secs(5))
            .unwrap();
        let _h2 = pool
            .submit(test_request("http://b.com/2"), 1, Duration::from_secs(5))
            .unwrap();
        let _h3 = pool
            .submit(test_request("http://a.com/3"), 1, Duration::from_secs(5))
            .unwrap();
        let _h4 = pool
            .submit(test_request("http://c.com/4"), 1, Duration::from_secs(5))
            .unwrap();

        let mut skip = HashSet::new();
        skip.insert("a.com".to_string());

        let t1 = pool.pick_next(&skip).unwrap();
        assert_eq!(t1.host, "b.com");

        let t2 = pool.pick_next(&skip).unwrap();
        assert_eq!(t2.host, "c.com");

        // Only the two a.com tasks remain, and both are skipped.
        assert_eq!(pool.pending_count(), 2);
        assert!(pool.pick_next(&skip).is_none());

        // The skipped tasks stay queued in their original relative order.
        let no_skip = HashSet::new();
        let rest: Vec<String> = std::iter::from_fn(|| pool.pick_next(&no_skip))
            .map(|e| e.request.url().path().to_string())
            .collect();
        assert_eq!(rest, ["/1", "/3"]);
    }
}
767}