twilight_http_ratelimiting/
lib.rs

#![doc = include_str!("../README.md")]
// Lint policy: enable clippy's pedantic group, flag `unsafe` code, and warn
// about undocumented public (`missing_docs`) and private
// (`clippy::missing_docs_in_private_items`) items. These are warnings, not
// hard errors.
#![warn(
    clippy::missing_const_for_fn,
    clippy::missing_docs_in_private_items,
    clippy::pedantic,
    missing_docs,
    unsafe_code
)]
// Two pedantic lints are opted out of crate-wide. NOTE(review): presumably
// `module_name_repetitions` because type names such as `RateLimiter` echo the
// crate name, and `must_use_candidate` to avoid annotating every pure getter —
// confirm with the maintainers.
#![allow(clippy::module_name_repetitions, clippy::must_use_candidate)]
10
11mod actor;
12
13use std::{
14    future::Future,
15    hash::{Hash as _, Hasher},
16    pin::Pin,
17    task::{Context, Poll},
18    time::{Duration, Instant},
19};
20use tokio::sync::{mpsc, oneshot};
21
/// Length of the global rate limit window.
///
/// Once the first globally limited request goes out, the remaining count is
/// restored to the global limit count after this much time has elapsed.
pub const GLOBAL_LIMIT_PERIOD: Duration = Duration::from_millis(1_000);
25
/// HTTP request [method].
///
/// [method]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum Method {
    /// Removes a resource.
    Delete,
    /// Fetches a resource.
    Get,
    /// Partially updates a resource.
    Patch,
    /// Creates a resource.
    Post,
    /// Replaces a resource.
    Put,
}

impl Method {
    /// Uppercase HTTP name of the method, e.g. `"GET"` for [`Method::Get`].
    pub const fn name(self) -> &'static str {
        match self {
            Self::Delete => "DELETE",
            Self::Get => "GET",
            Self::Patch => "PATCH",
            Self::Post => "POST",
            Self::Put => "PUT",
        }
    }
}
56
57/// Rate limited endpoint.
58///
59/// The rate limiter dynamically supports new or unknown API paths, but is consequently unable to
60/// catch invalid arguments. Invalidly structured endpoints may be permitted at an improper time.
61///
62/// # Example
63///
64/// ```no_run
65/// # let rt = tokio::runtime::Builder::new_current_thread()
66/// #     .enable_time()
67/// #     .build()
68/// #     .unwrap();
69/// # rt.block_on(async {
70/// # let rate_limiter = twilight_http_ratelimiting::RateLimiter::default();
71/// use twilight_http_ratelimiting::{Endpoint, Method};
72///
73/// let url = "https://discord.com/api/v10/guilds/745809834183753828/audit-logs?limit=10";
74/// let endpoint = Endpoint {
75///     method: Method::Get,
76///     path: String::from("guilds/745809834183753828/audit-logs"),
77/// };
78/// let permit = rate_limiter.acquire(endpoint).await;
79/// let headers = unimplemented!("GET {url}");
80/// permit.complete(headers);
81/// # });
82/// ```
83#[derive(Clone, Debug, Eq, Hash, PartialEq)]
84pub struct Endpoint {
85    /// Method of the endpoint.
86    pub method: Method,
87    /// API path of the endpoint.
88    ///
89    /// Should not start with a slash (`/`) or include query parameters (`?key=value`).
90    pub path: String,
91}
92
93impl Endpoint {
94    /// Whether the endpoint is properly structured.
95    pub(crate) fn is_valid(&self) -> bool {
96        !self.path.as_bytes().starts_with(b"/") && !self.path.as_bytes().contains(&b'?')
97    }
98
99    /// Whether the endpoint is an interaction.
100    pub(crate) fn is_interaction(&self) -> bool {
101        self.path.as_bytes().starts_with(b"webhooks")
102            || self.path.as_bytes().starts_with(b"interactions")
103    }
104
105    /// Feeds the top-level resources of this endpoint into the given [`Hasher`].
106    ///
107    /// Top-level resources represent the bucket namespace in which they are unique.
108    ///
109    /// Top-level resources are currently:
110    /// - `channels/<channel_id>`
111    /// - `guilds/<guild_id>`
112    /// - `webhooks/<webhook_id>`
113    /// - `webhooks/<webhook_id>/<webhook_token>`
114    pub(crate) fn hash_resources(&self, state: &mut impl Hasher) {
115        let mut segments = self.path.as_bytes().split(|&s| s == b'/');
116        match segments.next().unwrap_or_default() {
117            b"channels" => {
118                if let Some(s) = segments.next() {
119                    "channels".hash(state);
120                    s.hash(state);
121                }
122            }
123            b"guilds" => {
124                if let Some(s) = segments.next() {
125                    "guilds".hash(state);
126                    s.hash(state);
127                }
128            }
129            b"webhooks" => {
130                if let Some(s) = segments.next() {
131                    "webhooks".hash(state);
132                    s.hash(state);
133                }
134                if let Some(s) = segments.next() {
135                    s.hash(state);
136                }
137            }
138            _ => {}
139        }
140    }
141}
142
/// Rate limit headers parsed from the response to a user request.
///
/// Setting `limit` to zero marks the [`Bucket`] as exhausted until `reset_at`
/// has passed.
///
/// # Global limits
///
/// If the [`RateLimiter`] ever exceeds the global limit, please open an issue.
///
/// # Shared limits
///
/// Shared limits do not count towards the invalid request limit, so reacting
/// to them is optional. To preemptively exhaust the bucket until
/// `Reset-After` anyway, complete the [`Permit`] with
/// [`RateLimitHeaders::shared`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RateLimitHeaders {
    /// Identifier of the bucket.
    pub bucket: Vec<u8>,
    /// Number of requests the bucket allows in total before it is exhausted.
    pub limit: u16,
    /// Number of requests left before the bucket is exhausted.
    pub remaining: u16,
    /// Instant at which the bucket resets.
    pub reset_at: Instant,
}

impl RateLimitHeaders {
    /// Name of the bucket header, in lowercase.
    pub const BUCKET: &'static str = "x-ratelimit-bucket";

    /// Name of the limit header, in lowercase.
    pub const LIMIT: &'static str = "x-ratelimit-limit";

    /// Name of the remaining header, in lowercase.
    pub const REMAINING: &'static str = "x-ratelimit-remaining";

    /// Name of the reset-after header, in lowercase.
    pub const RESET_AFTER: &'static str = "x-ratelimit-reset-after";

    /// Name of the scope header, in lowercase.
    pub const SCOPE: &'static str = "x-ratelimit-scope";

    /// Builds headers that present a shared resource limit as a user limit.
    ///
    /// Both `limit` and `remaining` are zero, and `reset_at` lies
    /// `retry_after` seconds from now.
    pub fn shared(bucket: Vec<u8>, retry_after: u16) -> Self {
        let reset_at = Instant::now() + Duration::from_secs(u64::from(retry_after));

        Self {
            limit: 0,
            remaining: 0,
            bucket,
            reset_at,
        }
    }
}
195
196/// Permit to send a Discord HTTP API request to the acquired endpoint.
197#[derive(Debug)]
198#[must_use = "dropping the permit immediately cancels itself"]
199pub struct Permit(oneshot::Sender<Option<RateLimitHeaders>>);
200
201impl Permit {
202    /// Update the [`RateLimiter`] based on the response headers.
203    ///
204    /// Non-completed permits are regarded as cancelled, so only call this
205    /// on receiving a response.
206    #[allow(clippy::missing_panics_doc)]
207    pub fn complete(self, headers: Option<RateLimitHeaders>) {
208        self.0.send(headers).expect("actor is alive");
209    }
210}
211
212/// Future that completes when a permit is ready.
213#[derive(Debug)]
214#[must_use = "futures do nothing unless you `.await` or poll them"]
215pub struct PermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
216
217impl Future for PermitFuture {
218    type Output = Permit;
219
220    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
221        Pin::new(&mut self.0)
222            .poll(cx)
223            .map(|r| Permit(r.expect("actor is alive")))
224    }
225}
226
227/// Future that completes when a permit is ready or cancelled.
228#[derive(Debug)]
229#[must_use = "futures do nothing unless you `.await` or poll them"]
230pub struct MaybePermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
231
232impl Future for MaybePermitFuture {
233    type Output = Option<Permit>;
234
235    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
236        Pin::new(&mut self.0).poll(cx).map(|r| r.ok().map(Permit))
237    }
238}
239
/// Rate limit state of one or more paths, as last reported through
/// [`RateLimitHeaders`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub struct Bucket {
    /// How many permits the bucket allows in total before it is exhausted.
    pub limit: u16,
    /// How many permits are still left before the bucket is exhausted.
    pub remaining: u16,
    /// Instant at which the bucket resets.
    pub reset_at: Instant,
}
252
253/// Actor run closure pre-enqueue for early [`MaybePermitFuture`] cancellation.
254type Predicate = Box<dyn FnOnce(Option<Bucket>) -> bool + Send>;
255
256/// Discord HTTP client API rate limiter.
257///
258/// The [`RateLimiter`] runs an associated actor task to concurrently handle permit
259/// requests and responses.
260///
261/// Cloning a [`RateLimiter`] increments just the amount of senders for the actor.
262/// The actor completes when there are no senders and non-completed permits left.
263#[derive(Clone, Debug)]
264pub struct RateLimiter {
265    /// Actor message sender.
266    tx: mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
267}
268
269impl RateLimiter {
270    /// Create a new [`RateLimiter`] with a custom global limit.
271    pub fn new(global_limit: u16) -> Self {
272        let (tx, rx) = mpsc::unbounded_channel();
273        tokio::spawn(actor::runner(global_limit, rx));
274
275        Self { tx }
276    }
277
278    /// Await a single permit for this endpoint.
279    ///
280    /// Permits are queued per endpoint in the order they were requested.
281    #[allow(clippy::missing_panics_doc)]
282    pub fn acquire(&self, endpoint: Endpoint) -> PermitFuture {
283        let (tx, rx) = oneshot::channel();
284        self.tx
285            .send((
286                actor::Message {
287                    endpoint,
288                    notifier: tx,
289                },
290                None,
291            ))
292            .expect("actor is alive");
293
294        PermitFuture(rx)
295    }
296
297    /// Await a single permit for this endpoint, but only if the predicate evaluates
298    /// to `true`.
299    ///
300    /// Permits are queued per endpoint in the order they were requested.
301    ///
302    /// Note that the predicate is asynchronously called in the actor task.
303    ///
304    /// # Example
305    ///
306    /// ```no_run
307    /// # let rt = tokio::runtime::Builder::new_current_thread()
308    /// #     .enable_time()
309    /// #     .build()
310    /// #     .unwrap();
311    /// # rt.block_on(async {
312    /// # let rate_limiter = twilight_http_ratelimiting::RateLimiter::default();
313    /// use twilight_http_ratelimiting::{Endpoint, Method};
314    ///
315    /// let endpoint = Endpoint {
316    ///     method: Method::Get,
317    ///     path: String::from("applications/@me"),
318    /// };
319    /// if let Some(permit) = rate_limiter
320    ///     .acquire_if(endpoint, |b| b.is_none_or(|b| b.remaining > 10))
321    ///     .await
322    /// {
323    ///     let headers = unimplemented!("GET /applications/@me");
324    ///     permit.complete(headers);
325    /// }
326    /// # });
327    /// ```
328    #[allow(clippy::missing_panics_doc)]
329    pub fn acquire_if<P>(&self, endpoint: Endpoint, predicate: P) -> MaybePermitFuture
330    where
331        P: FnOnce(Option<Bucket>) -> bool + Send + 'static,
332    {
333        let (tx, rx) = oneshot::channel();
334        self.tx
335            .send((
336                actor::Message {
337                    endpoint,
338                    notifier: tx,
339                },
340                Some(Box::new(predicate)),
341            ))
342            .expect("actor is alive");
343
344        MaybePermitFuture(rx)
345    }
346
347    /// Retrieve the [`Bucket`] for this endpoint.
348    ///
349    /// The bucket is internally retrieved via [`acquire_if`][Self::acquire_if].
350    #[allow(clippy::missing_panics_doc)]
351    pub async fn bucket(&self, endpoint: Endpoint) -> Option<Bucket> {
352        let (tx, rx) = oneshot::channel();
353        self.acquire_if(endpoint, |bucket| {
354            _ = tx.send(bucket);
355            false
356        })
357        .await;
358
359        rx.await.expect("actor is alive")
360    }
361}
362
363impl Default for RateLimiter {
364    /// Create a new [`RateLimiter`] with Discord's default global limit.
365    ///
366    /// Currently this is `50`.
367    fn default() -> Self {
368        Self::new(50)
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::{
        Bucket, Endpoint, MaybePermitFuture, Method, Permit, PermitFuture, RateLimitHeaders,
        RateLimiter,
    };
    use static_assertions::assert_impl_all;
    use std::{
        fmt::Debug,
        future::Future,
        hash::{DefaultHasher, Hash, Hasher as _},
        time::{Duration, Instant},
    };
    use tokio::task;

    // Compile-time checks that the public types keep implementing these traits.
    assert_impl_all!(Bucket: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Endpoint: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(MaybePermitFuture: Debug, Future<Output = Option<Permit>>);
    assert_impl_all!(Method: Clone, Copy, Debug, Eq, PartialEq);
    assert_impl_all!(Permit: Debug, Send, Sync);
    assert_impl_all!(PermitFuture: Debug, Future<Output = Permit>);
    assert_impl_all!(RateLimitHeaders: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(RateLimiter: Clone, Debug, Default, Send, Sync);

    // Builds a fresh `GET applications/@me` endpoint for each use.
    const ENDPOINT: fn() -> Endpoint = || Endpoint {
        method: Method::Get,
        path: String::from("applications/@me"),
    };

    // A rejecting predicate yields no permit; an accepting one yields a permit.
    #[tokio::test]
    async fn acquire_if() {
        let rate_limiter = RateLimiter::default();

        assert!(
            rate_limiter
                .acquire_if(ENDPOINT(), |_| false)
                .await
                .is_none()
        );
        assert!(
            rate_limiter
                .acquire_if(ENDPOINT(), |_| true)
                .await
                .is_some()
        );
    }

    // Headers reported through a completed permit are reflected by `bucket`.
    #[tokio::test]
    async fn bucket() {
        let rate_limiter = RateLimiter::default();

        let limit = 2;
        let remaining = 1;
        let reset_at = Instant::now() + Duration::from_secs(1);
        let headers = RateLimitHeaders {
            bucket: vec![1, 2, 3],
            limit,
            remaining,
            reset_at,
        };

        rate_limiter
            .acquire(ENDPOINT())
            .await
            .complete(Some(headers));
        // Yield once, presumably so the actor task can process the completion
        // before the bucket is queried.
        task::yield_now().await;

        let bucket = rate_limiter.bucket(ENDPOINT()).await.unwrap();
        assert_eq!(bucket.limit, limit);
        assert_eq!(bucket.remaining, remaining);
        // `reset_at` must round-trip to within 1 ms in either direction.
        assert!(
            bucket.reset_at.saturating_duration_since(reset_at) < Duration::from_millis(1)
                && reset_at.saturating_duration_since(bucket.reset_at) < Duration::from_millis(1)
        );
    }

    // Runs `f` against a fresh `DefaultHasher` and returns the final hash.
    fn with_hasher(f: impl FnOnce(&mut DefaultHasher)) -> u64 {
        let mut hasher = DefaultHasher::new();
        f(&mut hasher);
        hasher.finish()
    }

    // Covers `is_valid`, `is_interaction` and `hash_resources`.
    #[test]
    fn endpoint() {
        // Leading slash and query string: both make the path invalid.
        let invalid = Endpoint {
            method: Method::Get,
            path: String::from("/guilds/745809834183753828/audit-logs?limit=10"),
        };
        let delete_webhook = Endpoint {
            method: Method::Delete,
            path: String::from("webhooks/1"),
        };
        let interaction_response = Endpoint {
            method: Method::Post,
            path: String::from("interactions/1/abc/callback"),
        };

        assert!(!invalid.is_valid());
        assert!(delete_webhook.is_valid());
        assert!(interaction_response.is_valid());

        assert!(delete_webhook.is_interaction());
        assert!(interaction_response.is_interaction());

        // The first segment of `invalid` is empty, so nothing is hashed.
        assert_eq!(
            with_hasher(|state| invalid.hash_resources(state)),
            with_hasher(|_| {})
        );
        assert_eq!(
            with_hasher(|state| delete_webhook.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
            })
        );
        // `interactions` is not a top-level resource, so nothing is hashed.
        assert_eq!(
            with_hasher(|state| interaction_response.hash_resources(state)),
            with_hasher(|_| {})
        );
    }

    // Every variant maps to its uppercase HTTP method name.
    #[test]
    fn method_conversions() {
        assert_eq!("DELETE", Method::Delete.name());
        assert_eq!("GET", Method::Get.name());
        assert_eq!("PATCH", Method::Patch.name());
        assert_eq!("POST", Method::Post.name());
        assert_eq!("PUT", Method::Put.name());
    }
}