1use core::hash::Hash;
22use core::time::Duration;
23use std::collections::HashMap;
24use std::sync::atomic::{AtomicBool, Ordering};
25use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
26
27use clock_lib::{Clock, Monotonic, SystemClock};
28use tokio::sync::Notify;
29
30use crate::decision::Decision;
31use crate::error::ThrottleError;
32use crate::limiter::Limiter;
33
/// Policy applied when a caller tries to join a queue that already holds `capacity` waiters.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Overflow {
    /// Turn the newcomer away with `ThrottleError::QueueFull`. This is the default policy.
    #[default]
    Reject,
    /// Evict the earliest-registered waiter (its `acquire` then fails with
    /// `ThrottleError::QueueFull`) and admit the newcomer in its place.
    DropOldest,
    /// Evict the lowest-priority waiter (the most recently registered one among equals)
    /// when the newcomer's priority is strictly higher; otherwise reject the newcomer.
    DropLowestPriority,
}
49
/// One caller parked in `Queue::acquire`, as tracked inside [`State`].
struct Waiter<K> {
    /// Registration order; `State::insert` sets it equal to the waiter's id.
    seq: u64,
    /// Scheduling priority; larger values are served first.
    priority: u32,
    /// Absolute expiry in queue-clock milliseconds; `None` never expires.
    deadline_ms: Option<u64>,
    /// Fairness key: at equal priority, the key served least recently goes first.
    key: K,
    /// Flag shared with the parked caller; raised when the waiter is evicted.
    evicted: Arc<AtomicBool>,
}

/// Mutable bookkeeping kept behind the queue's mutex.
struct State<K> {
    /// Registered waiters, keyed by the id handed out by `insert`.
    waiters: HashMap<u64, Waiter<K>>,
    /// Count of waiters served so far; doubles as a logical timestamp for `last_served`.
    service_seq: u64,
    /// Id that the next registered waiter will receive.
    next_seq: u64,
    /// For each key, the `service_seq` value at which it was most recently served.
    /// NOTE(review): entries are never removed, so this map grows with the number of
    /// distinct keys ever served.
    last_served: HashMap<K, u64>,
}

impl<K: Eq + Hash + Clone> State<K> {
    /// Empty bookkeeping with every counter at zero.
    fn new() -> Self {
        State {
            waiters: HashMap::new(),
            last_served: HashMap::new(),
            service_seq: 0,
            next_seq: 0,
        }
    }

    /// Whether a waiter with `deadline_ms` may still be served at `now_ms`.
    fn is_live(deadline_ms: Option<u64>, now_ms: u64) -> bool {
        match deadline_ms {
            Some(deadline) => now_ms < deadline,
            None => true,
        }
    }

    /// Drops every waiter whose deadline is at or before `now_ms`.
    ///
    /// The dropped waiters' `evicted` flags are left untouched.
    fn prune_expired(&mut self, now_ms: u64) {
        self.waiters
            .retain(|_, waiter| Self::is_live(waiter.deadline_ms, now_ms));
    }

    /// Id of the live waiter that should be served next, if any.
    ///
    /// Order: highest priority, then least-recently-served key, then earliest `seq`.
    fn winner(&self, now_ms: u64) -> Option<u64> {
        self.waiters
            .iter()
            .filter(|(_, waiter)| Self::is_live(waiter.deadline_ms, now_ms))
            .min_by_key(|&(_, waiter)| self.rank(waiter))
            .map(|(&id, _)| id)
    }

    /// Service-order key for `waiter`; smaller ranks are served first.
    fn rank(&self, waiter: &Waiter<K>) -> (u32, u64, u64) {
        // `u32::MAX - priority` turns "higher priority first" into ascending order.
        (
            u32::MAX - waiter.priority,
            self.recency(&waiter.key),
            waiter.seq,
        )
    }

    /// Logical time at which `key` was last served; 0 when it never was.
    fn recency(&self, key: &K) -> u64 {
        self.last_served.get(key).copied().unwrap_or_default()
    }

    /// Removes waiter `id` and records its key as served now.
    ///
    /// An unknown id is ignored and does not advance `service_seq`.
    fn serve(&mut self, id: u64) {
        let Some(served) = self.waiters.remove(&id) else {
            return;
        };
        self.service_seq += 1;
        let _ = self.last_served.insert(served.key, self.service_seq);
    }

    /// Registers a new waiter and returns its id plus the eviction flag shared with it.
    fn insert(
        &mut self,
        priority: u32,
        deadline_ms: Option<u64>,
        key: K,
    ) -> (u64, Arc<AtomicBool>) {
        let id = self.next_seq;
        self.next_seq += 1;

        let flag = Arc::new(AtomicBool::new(false));
        let waiter = Waiter {
            seq: id,
            priority,
            deadline_ms,
            key,
            evicted: Arc::clone(&flag),
        };
        let _ = self.waiters.insert(id, waiter);
        (id, flag)
    }

    /// Id of the earliest-registered waiter (smallest `seq`), expired or not.
    fn oldest(&self) -> Option<u64> {
        let (&id, _) = self.waiters.iter().min_by_key(|&(_, waiter)| waiter.seq)?;
        Some(id)
    }

    /// Id and priority of the lowest-priority waiter, preferring the newest among ties.
    fn weakest(&self) -> Option<(u64, u32)> {
        // `u64::MAX - seq` makes a larger (newer) seq sort first within a priority level.
        let (&id, waiter) = self
            .waiters
            .iter()
            .min_by_key(|&(_, waiter)| (waiter.priority, u64::MAX - waiter.seq))?;
        Some((id, waiter.priority))
    }
}
161
162pub struct Queue<L, K = (), C = SystemClock>
189where
190 K: Eq + Hash + Clone + Send + Sync,
191 C: Clock,
192{
193 inner: L,
194 state: Mutex<State<K>>,
195 notify: Notify,
196 capacity: usize,
197 overflow: Overflow,
198 clock: C,
199 epoch: Monotonic,
200}
201
202impl Queue<core::convert::Infallible, ()> {
205 #[must_use]
207 pub fn builder() -> QueueBuilder {
208 QueueBuilder::new()
209 }
210}
211
212impl<L, K, C> Queue<L, K, C>
213where
214 L: Limiter,
215 K: Eq + Hash + Clone + Send + Sync,
216 C: Clock + Clone,
217{
218 fn new(inner: L, capacity: usize, overflow: Overflow, clock: C) -> Self {
219 let epoch = clock.now();
220 Self {
221 inner,
222 state: Mutex::new(State::new()),
223 notify: Notify::new(),
224 capacity: capacity.max(1),
225 overflow,
226 clock,
227 epoch,
228 }
229 }
230
231 #[must_use]
234 pub fn with_clock<C2>(self, clock: C2) -> Queue<L, K, C2>
235 where
236 C2: Clock + Clone,
237 {
238 Queue::new(self.inner, self.capacity, self.overflow, clock)
239 }
240
241 #[must_use]
243 pub fn len(&self) -> usize {
244 self.lock().waiters.len()
245 }
246
247 #[must_use]
249 pub fn is_empty(&self) -> bool {
250 self.lock().waiters.is_empty()
251 }
252
253 #[must_use]
255 pub fn capacity(&self) -> usize {
256 self.capacity
257 }
258
259 pub fn inner(&self) -> &L {
261 &self.inner
262 }
263
264 #[inline]
265 fn lock(&self) -> MutexGuard<'_, State<K>> {
266 self.state.lock().unwrap_or_else(PoisonError::into_inner)
267 }
268
269 #[inline]
270 fn now_ms(&self) -> u64 {
271 let elapsed = self.clock.now().saturating_duration_since(self.epoch);
272 u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
273 }
274
275 fn register(
281 &self,
282 now_ms: u64,
283 priority: u32,
284 deadline_ms: Option<u64>,
285 key: &K,
286 ) -> Result<(u64, Arc<AtomicBool>), ThrottleError> {
287 let mut did_evict = false;
288 let outcome = {
289 let mut state = self.lock();
290 state.prune_expired(now_ms);
291
292 if state.waiters.len() < self.capacity {
293 Ok(state.insert(priority, deadline_ms, key.clone()))
294 } else {
295 match self.overflow {
296 Overflow::Reject => Err(ThrottleError::QueueFull),
297 Overflow::DropOldest => match state.oldest() {
298 Some(victim) => {
299 evict(&mut state, victim);
300 did_evict = true;
301 Ok(state.insert(priority, deadline_ms, key.clone()))
302 }
303 None => Err(ThrottleError::QueueFull),
304 },
305 Overflow::DropLowestPriority => match state.weakest() {
306 Some((victim, weakest)) if priority > weakest => {
308 evict(&mut state, victim);
309 did_evict = true;
310 Ok(state.insert(priority, deadline_ms, key.clone()))
311 }
312 _ => Err(ThrottleError::QueueFull),
313 },
314 }
315 }
316 };
317
318 if did_evict || outcome.is_ok() {
319 self.notify.notify_waiters();
320 }
321 outcome
322 }
323
324 pub async fn acquire(
339 &self,
340 key: K,
341 priority: u32,
342 deadline: Option<Duration>,
343 ) -> Result<(), ThrottleError> {
344 let start_ms = self.now_ms();
345 let deadline_ms = deadline
346 .map(|d| start_ms.saturating_add(u64::try_from(d.as_millis()).unwrap_or(u64::MAX)));
347
348 let (id, evicted) = self.register(start_ms, priority, deadline_ms, &key)?;
349 let _guard = LeaveGuard { queue: self, id };
351
352 loop {
353 let notified = self.notify.notified();
356 tokio::pin!(notified);
357 let _ = notified.as_mut().enable();
360
361 if evicted.load(Ordering::Acquire) {
362 return Err(ThrottleError::QueueFull);
363 }
364
365 let now_ms = self.now_ms();
366 if deadline_ms.is_some_and(|d| now_ms >= d) {
367 return Err(ThrottleError::DeadlineExceeded);
368 }
369
370 let wait = {
371 let mut state = self.lock();
372 if state.winner(now_ms) == Some(id) {
373 match self.inner.acquire_cost(1) {
374 Decision::Acquired => {
375 state.serve(id);
376 drop(state);
377 self.notify.notify_waiters();
378 return Ok(());
379 }
380 Decision::Impossible => {
381 return Err(ThrottleError::CostExceedsCapacity {
382 cost: 1,
383 capacity: self.inner.capacity(),
384 });
385 }
386 Decision::Retry { after } => after,
388 }
389 } else {
390 Duration::from_secs(3600)
392 }
393 };
394
395 let sleep_for = cap_to_deadline(wait, now_ms, deadline_ms);
396 tokio::select! {
397 () = notified.as_mut() => {}
398 () = tokio::time::sleep(sleep_for) => {}
399 }
400 }
401 }
402}
403
/// Shortens `wait` so a sleep starting at `now_ms` ends no later than `deadline_ms`.
///
/// A deadline already in the past yields a zero duration; no deadline leaves `wait` as is.
fn cap_to_deadline(wait: Duration, now_ms: u64, deadline_ms: Option<u64>) -> Duration {
    deadline_ms.map_or(wait, |deadline| {
        let remaining = Duration::from_millis(deadline.saturating_sub(now_ms));
        wait.min(remaining)
    })
}
411
412fn evict<K: Eq + Hash + Clone>(state: &mut State<K>, id: u64) {
414 if let Some(w) = state.waiters.remove(&id) {
415 w.evicted.store(true, Ordering::Release);
416 }
417}
418
419struct LeaveGuard<'a, L, K, C>
421where
422 L: Limiter,
423 K: Eq + Hash + Clone + Send + Sync,
424 C: Clock + Clone,
425{
426 queue: &'a Queue<L, K, C>,
427 id: u64,
428}
429
430impl<L, K, C> Drop for LeaveGuard<'_, L, K, C>
431where
432 L: Limiter,
433 K: Eq + Hash + Clone + Send + Sync,
434 C: Clock + Clone,
435{
436 fn drop(&mut self) {
437 {
438 let mut state = self.queue.lock();
439 let _ = state.waiters.remove(&self.id);
440 }
441 self.queue.notify.notify_waiters();
443 }
444}
445
446#[derive(Debug, Clone, Copy)]
448pub struct QueueBuilder {
449 capacity: usize,
450 overflow: Overflow,
451}
452
453impl Default for QueueBuilder {
454 fn default() -> Self {
455 Self::new()
456 }
457}
458
459impl QueueBuilder {
460 #[must_use]
462 pub fn new() -> Self {
463 Self {
464 capacity: 1024,
465 overflow: Overflow::Reject,
466 }
467 }
468
469 #[must_use]
471 pub fn capacity(mut self, capacity: usize) -> Self {
472 self.capacity = capacity.max(1);
473 self
474 }
475
476 #[must_use]
478 pub fn overflow(mut self, overflow: Overflow) -> Self {
479 self.overflow = overflow;
480 self
481 }
482
483 #[must_use]
485 pub fn build<L, K>(self, limiter: L) -> Queue<L, K, SystemClock>
486 where
487 L: Limiter,
488 K: Eq + Hash + Clone + Send + Sync,
489 {
490 Queue::new(limiter, self.capacity, self.overflow, SystemClock::new())
491 }
492}
493
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    // Tests unwrap freely: a panic is the intended failure report here.
    //
    // The async tests drive `Queue::acquire` end to end against a real `Throttle`.
    // The synchronous tests at the bottom poke `State` directly, pinning the
    // winner-selection rules without any timing involved.

    use super::{Overflow, Queue};
    use crate::throttle::Throttle;
    use core::time::Duration;
    use std::sync::Arc;

    // Compile-time helper: only instantiable for `Send + Sync` types.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_queue_is_send_sync() {
        assert_send_sync::<Queue<Throttle, &'static str>>();
    }

    #[tokio::test]
    async fn test_immediate_acquire_when_token_is_free() {
        // The lone caller is served, and its waiter entry is removed on the way out.
        let queue: Queue<Throttle, ()> = Queue::builder().build(Throttle::per_second(10));
        assert!(queue.acquire((), 0, None).await.is_ok());
        assert!(queue.is_empty());
    }

    #[tokio::test]
    async fn test_cost_exceeds_capacity_is_reported() {
        // `per_second(0)` is expected to make the limiter answer `Impossible`. That must
        // surface as `CostExceedsCapacity`, not as a deadline error.
        let queue: Queue<Throttle, ()> = Queue::builder().build(Throttle::per_second(0));
        let err = queue.acquire((), 0, Some(Duration::from_secs(1))).await;
        assert!(matches!(
            err,
            Err(crate::ThrottleError::CostExceedsCapacity { .. })
        ));
    }

    #[tokio::test]
    async fn test_deadline_exceeded_when_no_token_arrives() {
        // The first call takes the only unit (presumably one per hour). The second call
        // cannot be served within 30 ms and must leave no waiter behind.
        let queue: Queue<Throttle, ()> =
            Queue::builder().build(Throttle::per_duration(1, Duration::from_secs(3600)));
        assert!(queue.acquire((), 0, None).await.is_ok());
        let err = queue.acquire((), 0, Some(Duration::from_millis(30))).await;
        assert!(matches!(err, Err(crate::ThrottleError::DeadlineExceeded)));
        assert!(queue.is_empty(), "the expired waiter is removed");
    }

    #[tokio::test]
    async fn test_reject_overflow_when_full() {
        // Capacity 1 with `Reject`. Once one caller is parked, the next arrival is
        // turned away with `QueueFull`.
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(1)
                .overflow(Overflow::Reject)
                .build(Throttle::per_duration(1, Duration::from_secs(3600))),
        );
        assert!(queue.acquire((), 0, None).await.is_ok());
        let q = Arc::clone(&queue);
        let parked = tokio::spawn(async move { q.acquire((), 0, None).await });
        // Spin until the spawned task has registered its waiter.
        while queue.is_empty() {
            tokio::task::yield_now().await;
        }
        let rejected = queue.acquire((), 0, Some(Duration::from_secs(1))).await;
        assert!(matches!(rejected, Err(crate::ThrottleError::QueueFull)));
        parked.abort();
    }

    #[tokio::test]
    async fn test_drop_oldest_overflow_evicts_the_first_waiter() {
        // Capacity 1 with `DropOldest`. A second parked caller evicts the first, whose
        // `acquire` then reports `QueueFull`.
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(1)
                .overflow(Overflow::DropOldest)
                .build(Throttle::per_duration(1, Duration::from_secs(3600))),
        );
        assert!(queue.acquire((), 0, None).await.is_ok());
        let q = Arc::clone(&queue);
        let first = tokio::spawn(async move { q.acquire((), 0, None).await });
        // Spin until the first task has registered its waiter.
        while queue.is_empty() {
            tokio::task::yield_now().await;
        }
        let q = Arc::clone(&queue);
        let second = tokio::spawn(async move { q.acquire((), 0, None).await });
        let first_result = first.await.unwrap();
        assert!(matches!(first_result, Err(crate::ThrottleError::QueueFull)));
        second.abort();
    }

    #[tokio::test]
    async fn test_priority_is_served_high_first() {
        use std::sync::atomic::{AtomicU32, Ordering};

        // Three callers with priorities 1, 5 and 3 park after the first unit is taken
        // (presumably the next unit arrives about 50 ms later). They must be released
        // highest priority first. `started` only counts task starts and is never asserted on.
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(10)
                .build(Throttle::per_duration(1, Duration::from_millis(50))),
        );
        assert!(queue.acquire((), 0, None).await.is_ok());
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));
        let started = Arc::new(AtomicU32::new(0));

        let mut handles = Vec::new();
        for priority in [1u32, 5, 3] {
            let q = Arc::clone(&queue);
            let order = Arc::clone(&order);
            let started = Arc::clone(&started);
            handles.push(tokio::spawn(async move {
                let _ = started.fetch_add(1, Ordering::Relaxed);
                q.acquire((), priority, None).await.unwrap();
                order.lock().unwrap().push(priority);
            }));
        }
        // Wait until all three waiters are registered before any is released.
        while queue.len() < 3 {
            tokio::task::yield_now().await;
        }
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(*order.lock().unwrap(), vec![5, 3, 1]);
    }

    #[test]
    fn test_fair_winner_rotates_across_keys_at_equal_priority() {
        use super::{State, Waiter};
        use std::sync::atomic::AtomicBool;

        // Inserts a waiter directly, with `seq` equal to its id.
        fn enqueue(state: &mut State<&'static str>, id: u64, priority: u32, key: &'static str) {
            let _ = state.waiters.insert(
                id,
                Waiter {
                    seq: id,
                    priority,
                    deadline_ms: None,
                    key,
                    evicted: Arc::new(AtomicBool::new(false)),
                },
            );
        }

        // Equal priority. After key "a" is served, "b" goes before a's second waiter,
        // even though that waiter arrived earlier.
        let mut state = State::<&'static str>::new();
        enqueue(&mut state, 0, 0, "a");
        enqueue(&mut state, 1, 0, "a");
        enqueue(&mut state, 2, 0, "b");

        assert_eq!(state.winner(0), Some(0));
        state.serve(0);
        assert_eq!(state.winner(0), Some(2));
        state.serve(2);
        assert_eq!(state.winner(0), Some(1));
    }

    #[test]
    fn test_priority_beats_fairness_in_winner_selection() {
        use super::{State, Waiter};
        use std::sync::atomic::AtomicBool;

        // A higher priority wins even against an earlier arrival.
        let mut state = State::<&'static str>::new();
        let _ = state.waiters.insert(
            0,
            Waiter {
                seq: 0,
                priority: 1,
                deadline_ms: None,
                key: "a",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        let _ = state.waiters.insert(
            1,
            Waiter {
                seq: 1,
                priority: 9,
                deadline_ms: None,
                key: "b",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        assert_eq!(state.winner(0), Some(1));
    }

    #[test]
    fn test_winner_skips_expired_waiters() {
        use super::{State, Waiter};
        use std::sync::atomic::AtomicBool;

        // Waiter 0 has the higher priority, but its deadline (100) is past at now = 200.
        let mut state = State::<&'static str>::new();
        let _ = state.waiters.insert(
            0,
            Waiter {
                seq: 0,
                priority: 9,
                deadline_ms: Some(100),
                key: "a",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        let _ = state.waiters.insert(
            1,
            Waiter {
                seq: 1,
                priority: 1,
                deadline_ms: None,
                key: "b",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        assert_eq!(state.winner(200), Some(1));
    }
}