twilight_http_ratelimiting/
lib.rs1#![doc = include_str!("../README.md")]
2#![warn(
3 clippy::missing_const_for_fn,
4 clippy::missing_docs_in_private_items,
5 clippy::pedantic,
6 missing_docs,
7 unsafe_code
8)]
9#![allow(clippy::module_name_repetitions, clippy::must_use_candidate)]
10
11mod actor;
12
13use std::{
14 future::Future,
15 hash::{Hash as _, Hasher},
16 pin::Pin,
17 task::{Context, Poll},
18 time::{Duration, Instant},
19};
20use tokio::sync::{mpsc, oneshot};
21
/// Length of the global rate limit period: one second.
pub const GLOBAL_LIMIT_PERIOD: Duration = Duration::from_millis(1_000);
25
/// Panic message used whenever a channel to or from the actor task turns out
/// to be closed, which this crate treats as the actor task having panicked.
const ACTOR_PANIC_MESSAGE: &'static str = "actor task panicked: report its panic message to the maintainers";
29
/// HTTP method of a request made against an [`Endpoint`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum Method {
    /// The `DELETE` method.
    Delete,
    /// The `GET` method.
    Get,
    /// The `PATCH` method.
    Patch,
    /// The `POST` method.
    Post,
    /// The `PUT` method.
    Put,
}

impl Method {
    /// Uppercase name of the method, e.g. `"GET"` for [`Method::Get`].
    pub const fn name(self) -> &'static str {
        match self {
            Self::Delete => "DELETE",
            Self::Get => "GET",
            Self::Patch => "PATCH",
            Self::Post => "POST",
            Self::Put => "PUT",
        }
    }
}
60
61#[derive(Clone, Debug, Eq, Hash, PartialEq)]
88pub struct Endpoint {
89 pub method: Method,
91 pub path: String,
95}
96
97impl Endpoint {
98 pub(crate) fn is_valid(&self) -> bool {
100 !self.path.as_bytes().starts_with(b"/") && !self.path.as_bytes().contains(&b'?')
101 }
102
103 pub(crate) fn is_interaction(&self) -> bool {
105 self.path.as_bytes().starts_with(b"webhooks")
106 || self.path.as_bytes().starts_with(b"interactions")
107 }
108
109 pub(crate) fn hash_resources(&self, state: &mut impl Hasher) {
119 let mut segments = self.path.as_bytes().split(|&s| s == b'/');
120 match segments.next().unwrap_or_default() {
121 b"channels" => {
122 if let Some(s) = segments.next() {
123 "channels".hash(state);
124 s.hash(state);
125 }
126 }
127 b"guilds" => {
128 if let Some(s) = segments.next() {
129 "guilds".hash(state);
130 s.hash(state);
131 }
132 }
133 b"webhooks" => {
134 if let Some(s) = segments.next() {
135 "webhooks".hash(state);
136 s.hash(state);
137 }
138 if let Some(s) = segments.next() {
139 s.hash(state);
140 }
141 }
142 _ => {}
143 }
144 }
145}
146
/// Rate limit information of a response, corresponding to the
/// `x-ratelimit-*` headers named by the associated constants.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RateLimitHeaders {
    /// Bucket identifier (see [`RateLimitHeaders::BUCKET`]).
    pub bucket: Vec<u8>,
    /// Request limit of the bucket (see [`RateLimitHeaders::LIMIT`]).
    pub limit: u16,
    /// Requests remaining in the bucket (see [`RateLimitHeaders::REMAINING`]).
    pub remaining: u16,
    /// Instant at which the bucket resets (see
    /// [`RateLimitHeaders::RESET_AFTER`]).
    pub reset_at: Instant,
}

impl RateLimitHeaders {
    /// Name of the bucket identifier header.
    pub const BUCKET: &'static str = "x-ratelimit-bucket";

    /// Name of the request limit header.
    pub const LIMIT: &'static str = "x-ratelimit-limit";

    /// Name of the remaining requests header.
    pub const REMAINING: &'static str = "x-ratelimit-remaining";

    /// Name of the header giving the time until the limit resets.
    pub const RESET_AFTER: &'static str = "x-ratelimit-reset-after";

    /// Name of the rate limit scope header.
    pub const SCOPE: &'static str = "x-ratelimit-scope";

    /// Creates headers for a shared rate limit.
    ///
    /// `limit` and `remaining` are both zero, and `reset_at` lies
    /// `retry_after` seconds from now.
    pub fn shared(bucket: Vec<u8>, retry_after: u16) -> Self {
        let reset_in = Duration::from_secs(u64::from(retry_after));

        Self {
            bucket,
            limit: 0,
            remaining: 0,
            reset_at: Instant::now() + reset_in,
        }
    }
}
199
200#[derive(Debug)]
202#[must_use = "dropping the permit immediately cancels itself"]
203pub struct Permit(oneshot::Sender<Option<RateLimitHeaders>>);
204
205impl Permit {
206 #[allow(clippy::missing_panics_doc)]
211 pub fn complete(self, headers: Option<RateLimitHeaders>) {
212 assert!(self.0.send(headers).is_ok(), "{ACTOR_PANIC_MESSAGE}");
213 }
214}
215
216#[derive(Debug)]
218#[must_use = "futures do nothing unless you `.await` or poll them"]
219pub struct PermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
220
221impl Future for PermitFuture {
222 type Output = Permit;
223
224 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
225 #[allow(clippy::match_wild_err_arm)]
226 Pin::new(&mut self.0).poll(cx).map(|r| match r {
227 Ok(sender) => Permit(sender),
228 Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
229 })
230 }
231}
232
233#[derive(Debug)]
235#[must_use = "futures do nothing unless you `.await` or poll them"]
236pub struct MaybePermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
237
238impl Future for MaybePermitFuture {
239 type Output = Option<Permit>;
240
241 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
242 Pin::new(&mut self.0).poll(cx).map(|r| r.ok().map(Permit))
243 }
244}
245
/// Snapshot of an endpoint's rate limit bucket.
///
/// It is passed to [`RateLimiter::acquire_if`] predicates and returned by
/// [`RateLimiter::bucket`]. The fields mirror the matching
/// [`RateLimitHeaders`] fields last reported through [`Permit::complete`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub struct Bucket {
    /// Request limit of the bucket.
    pub limit: u16,
    /// Requests remaining in the bucket.
    pub remaining: u16,
    /// Instant at which the bucket resets.
    pub reset_at: Instant,
}
258
259type Predicate = Box<dyn FnOnce(Option<Bucket>) -> bool + Send>;
261
262#[derive(Clone, Debug)]
270pub struct RateLimiter {
271 tx: mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
273}
274
275impl RateLimiter {
276 pub fn new(global_limit: u16) -> Self {
278 let (tx, rx) = mpsc::unbounded_channel();
279 tokio::spawn(actor::runner(global_limit, rx));
280
281 Self { tx }
282 }
283
284 #[allow(clippy::missing_panics_doc)]
288 pub fn acquire(&self, endpoint: Endpoint) -> PermitFuture {
289 let (notifier, rx) = oneshot::channel();
290 let message = actor::Message { endpoint, notifier };
291 assert!(
292 self.tx.send((message, None)).is_ok(),
293 "{ACTOR_PANIC_MESSAGE}"
294 );
295
296 PermitFuture(rx)
297 }
298
299 #[allow(clippy::missing_panics_doc)]
331 pub fn acquire_if<P>(&self, endpoint: Endpoint, predicate: P) -> MaybePermitFuture
332 where
333 P: FnOnce(Option<Bucket>) -> bool + Send + 'static,
334 {
335 let (notifier, rx) = oneshot::channel();
336 let message = actor::Message { endpoint, notifier };
337 assert!(
338 self.tx.send((message, Some(Box::new(predicate)))).is_ok(),
339 "{ACTOR_PANIC_MESSAGE}"
340 );
341
342 MaybePermitFuture(rx)
343 }
344
345 #[allow(clippy::missing_panics_doc)]
349 pub async fn bucket(&self, endpoint: Endpoint) -> Option<Bucket> {
350 let (tx, rx) = oneshot::channel();
351 self.acquire_if(endpoint, |bucket| {
352 _ = tx.send(bucket);
354 false
355 })
356 .await;
357
358 #[allow(clippy::match_wild_err_arm)]
359 match rx.await {
360 Ok(bucket) => bucket,
361 Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
362 }
363 }
364}
365
366impl Default for RateLimiter {
367 fn default() -> Self {
371 Self::new(50)
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::{
        Bucket, Endpoint, MaybePermitFuture, Method, Permit, PermitFuture, RateLimitHeaders,
        RateLimiter,
    };
    use static_assertions::assert_impl_all;
    use std::{
        fmt::Debug,
        future::Future,
        hash::{DefaultHasher, Hash, Hasher as _},
        time::{Duration, Instant},
    };
    use tokio::task;

    assert_impl_all!(Bucket: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Endpoint: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(
        MaybePermitFuture: Debug,
        Future<Output = Option<Permit>>,
        Send,
        Sync
    );
    // `Method` derives `Hash`; assert it alongside the other derived traits.
    assert_impl_all!(Method: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Permit: Debug, Send, Sync);
    assert_impl_all!(PermitFuture: Debug, Future<Output = Permit>, Send, Sync);
    assert_impl_all!(RateLimitHeaders: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(RateLimiter: Clone, Debug, Default, Send, Sync);

    const ENDPOINT: fn() -> Endpoint = || Endpoint {
        method: Method::Get,
        path: String::from("applications/@me"),
    };

    #[tokio::test]
    async fn acquire_if() {
        let rate_limiter = RateLimiter::default();

        assert!(
            rate_limiter
                .acquire_if(ENDPOINT(), |_| false)
                .await
                .is_none()
        );
        assert!(
            rate_limiter
                .acquire_if(ENDPOINT(), |_| true)
                .await
                .is_some()
        );
    }

    #[tokio::test]
    async fn bucket() {
        let rate_limiter = RateLimiter::default();

        let limit = 2;
        let remaining = 1;
        let reset_at = Instant::now() + Duration::from_secs(1);
        let headers = RateLimitHeaders {
            bucket: vec![1, 2, 3],
            limit,
            remaining,
            reset_at,
        };

        rate_limiter
            .acquire(ENDPOINT())
            .await
            .complete(Some(headers));
        task::yield_now().await;

        let bucket = rate_limiter.bucket(ENDPOINT()).await.unwrap();
        assert_eq!(bucket.limit, limit);
        assert_eq!(bucket.remaining, remaining);
        assert!(
            bucket.reset_at.saturating_duration_since(reset_at) < Duration::from_millis(1)
                && reset_at.saturating_duration_since(bucket.reset_at) < Duration::from_millis(1)
        );
    }

    #[test]
    fn shared_headers() {
        let before = Instant::now();
        let headers = RateLimitHeaders::shared(vec![1, 2], 5);
        let after = Instant::now();

        assert_eq!(headers.bucket, vec![1, 2]);
        assert_eq!(headers.limit, 0);
        assert_eq!(headers.remaining, 0);
        assert!(headers.reset_at >= before + Duration::from_secs(5));
        assert!(headers.reset_at <= after + Duration::from_secs(5));
    }

    /// Hashes whatever `f` feeds into a fresh [`DefaultHasher`].
    fn with_hasher(f: impl FnOnce(&mut DefaultHasher)) -> u64 {
        let mut hasher = DefaultHasher::new();
        f(&mut hasher);
        hasher.finish()
    }

    /// Builds an [`Endpoint`] from a method and a path literal.
    fn endpoint_at(method: Method, path: &str) -> Endpoint {
        Endpoint {
            method,
            path: String::from(path),
        }
    }

    #[test]
    fn endpoint() {
        let invalid = endpoint_at(
            Method::Get,
            "/guilds/745809834183753828/audit-logs?limit=10",
        );
        let leading_slash = endpoint_at(Method::Get, "/channels/1");
        let query = endpoint_at(Method::Get, "channels/1/messages?limit=10");
        let delete_webhook = endpoint_at(Method::Delete, "webhooks/1");
        let execute_webhook = endpoint_at(Method::Post, "webhooks/1/abc");
        let interaction_response = endpoint_at(Method::Post, "interactions/1/abc/callback");
        let create_message = endpoint_at(Method::Post, "channels/2/messages");
        let audit_log = endpoint_at(Method::Get, "guilds/3/audit-logs");

        assert!(!invalid.is_valid());
        assert!(!leading_slash.is_valid());
        assert!(!query.is_valid());
        assert!(delete_webhook.is_valid());
        assert!(execute_webhook.is_valid());
        assert!(interaction_response.is_valid());
        assert!(create_message.is_valid());
        assert!(audit_log.is_valid());

        assert!(delete_webhook.is_interaction());
        assert!(execute_webhook.is_interaction());
        assert!(interaction_response.is_interaction());
        assert!(!create_message.is_interaction());
        assert!(!audit_log.is_interaction());

        assert_eq!(
            with_hasher(|state| invalid.hash_resources(state)),
            with_hasher(|_| {})
        );
        assert_eq!(
            with_hasher(|state| delete_webhook.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| execute_webhook.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
                b"abc".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| interaction_response.hash_resources(state)),
            with_hasher(|_| {})
        );
        assert_eq!(
            with_hasher(|state| create_message.hash_resources(state)),
            with_hasher(|state| {
                "channels".hash(state);
                b"2".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| audit_log.hash_resources(state)),
            with_hasher(|state| {
                "guilds".hash(state);
                b"3".hash(state);
            })
        );
    }

    #[test]
    fn method_conversions() {
        assert_eq!("DELETE", Method::Delete.name());
        assert_eq!("GET", Method::Get.name());
        assert_eq!("PATCH", Method::Patch.name());
        assert_eq!("POST", Method::Post.name());
        assert_eq!("PUT", Method::Put.name());
    }
}