twilight_http/response/future.rs
1use super::{BytesFuture, Response};
2use crate::{
3 api_error::ApiError,
4 client::connector::Connector,
5 error::{Error, ErrorType},
6};
7use http::{HeaderMap, HeaderValue, Request, StatusCode, header};
8use http_body_util::Full;
9use hyper::body::Bytes;
10use hyper_util::client::legacy::{Client as HyperClient, ResponseFuture as HyperResponseFuture};
11use std::{
12 future::{Future, Ready, ready},
13 marker::PhantomData,
14 pin::Pin,
15 sync::{
16 Arc,
17 atomic::{AtomicBool, Ordering},
18 },
19 task::{Context, Poll, ready},
20 time::{Duration, Instant},
21};
22use tokio::time::{self, Timeout};
23use twilight_http_ratelimiting::{Endpoint, Permit, PermitFuture, RateLimitHeaders, RateLimiter};
24
25/// Parse ratelimit headers from a map of headers.
26///
27/// # Errors
28///
29/// Errors if a required header is missing or if a header value is of an
30/// invalid type.
31fn parse_ratelimit_headers(
32 headers: &HeaderMap,
33) -> Result<Option<RateLimitHeaders>, Box<dyn std::error::Error>> {
34 match headers
35 .get(RateLimitHeaders::SCOPE)
36 .map(HeaderValue::as_bytes)
37 {
38 Some(b"global") => {
39 tracing::info!("globally rate limited");
40
41 Ok(None)
42 }
43 Some(b"shared") => {
44 let bucket = headers
45 .get(RateLimitHeaders::BUCKET)
46 .ok_or("missing bucket header")?
47 .as_bytes()
48 .to_vec();
49 let retry_after = headers
50 .get(header::RETRY_AFTER)
51 .ok_or("missing retry-after header")?
52 .to_str()?
53 .parse()?;
54
55 Ok(Some(RateLimitHeaders::shared(bucket, retry_after)))
56 }
57 Some(b"user") => {
58 let bucket = headers
59 .get(RateLimitHeaders::BUCKET)
60 .ok_or("missing bucket header")?
61 .as_bytes()
62 .to_vec();
63 let limit = headers
64 .get(RateLimitHeaders::LIMIT)
65 .ok_or("missing limit header")?
66 .to_str()?
67 .parse()?;
68 let remaining = headers
69 .get(RateLimitHeaders::REMAINING)
70 .ok_or("missing remaining header")?
71 .to_str()?
72 .parse()?;
73 let reset_after = headers
74 .get(RateLimitHeaders::RESET_AFTER)
75 .ok_or("missing reset-after header")?
76 .to_str()?
77 .parse()?;
78
79 Ok(Some(RateLimitHeaders {
80 bucket,
81 limit,
82 remaining,
83 reset_at: Instant::now() + Duration::from_secs_f32(reset_after),
84 }))
85 }
86 _ => Ok(None),
87 }
88}
89
/// Sub-futures of [`ResponseFuture`].
///
/// Exactly one stage is active at a time; [`ResponseFuture`]'s `poll`
/// replaces the stage in place as the request progresses.
enum ResponseStageFuture {
    /// Future that completes with an error response body.
    ///
    /// Entered when a response has a status that is neither a success nor
    /// `429 Too Many Requests`.
    Error {
        /// Inner response body future.
        fut: BytesFuture,
        /// Erroneous response status code.
        status: StatusCode,
    },
    /// Future that completes when a rate limit permit is ready.
    ///
    /// Only used when a rate limiter is configured.
    RateLimitPermit(PermitFuture),
    /// Future that completes with a response or timeout.
    Response {
        /// Inner timed response future.
        fut: Pin<Box<Timeout<HyperResponseFuture>>>,
        /// Optional rate limit permit.
        ///
        /// `None` when no rate limiter is configured. Otherwise it is taken
        /// and completed once the response has been received.
        permit: Option<Permit>,
    },
}
109
110/// [`PermitFuture`] generator.
111struct PermitFutureGenerator {
112 /// Rate limiter to acquire permits from.
113 rate_limiter: RateLimiter,
114 /// Rate limiter endpoint to acquire permits for.
115 endpoint: Endpoint,
116}
117
118impl PermitFutureGenerator {
119 /// Generates a permit future.
120 fn generate(&self) -> PermitFuture {
121 self.rate_limiter.acquire(self.endpoint.clone())
122 }
123}
124
125/// [`Timeout<HyperResponseFuture>`] generator.
126struct TimedResponseFutureGenerator {
127 /// HTTP client to send requests from.
128 client: HyperClient<Connector, Full<Bytes>>,
129 /// HTTP request to send.
130 request: Request<Full<Bytes>>,
131 /// Duration after which the request times out.
132 timeout: Duration,
133}
134
135impl TimedResponseFutureGenerator {
136 /// Generates a timeout response future.
137 fn generate(&self) -> Pin<Box<Timeout<HyperResponseFuture>>> {
138 Box::pin(time::timeout(
139 self.timeout,
140 self.client.request(self.request.clone()),
141 ))
142 }
143}
144
/// Future that completes when a [`Response`] is received.
///
/// # Rate limits
///
/// Requests that exceed a rate limit are automatically and immediately retried
/// until they succeed or fail with another error. If configured without a
/// [`RateLimiter`], care must be taken that an external service intercepts and
/// delays these retry requests.
///
/// # Canceling a response future pre-flight
///
/// Response futures can be canceled pre-flight via
/// [`ResponseFuture::set_pre_flight`]. This allows you to cancel requests that
/// are no longer necessary once they have been cleared by the ratelimit queue,
/// which may be necessary in scenarios where requests are being spammed. Refer
/// to its documentation for more information.
///
/// # Errors
///
/// Returns an [`ErrorType::Parsing`] error type if the request failed and the
/// error in the response body could not be deserialized.
///
/// Returns an [`ErrorType::RequestCanceled`] error type if the request was
/// canceled by the user.
///
/// Returns an [`ErrorType::RequestError`] error type if sending the request
/// failed or if the body of an error response could not be read.
///
/// Returns an [`ErrorType::RequestTimedOut`] error type if the request timed
/// out. The timeout value is configured via [`ClientBuilder::timeout`].
///
/// Returns an [`ErrorType::Response`] error type if the request failed and
/// the error in the response body was deserialized.
///
/// [`ClientBuilder::timeout`]: crate::client::ClientBuilder::timeout
/// [`ErrorType::Parsing`]: crate::error::ErrorType::Parsing
/// [`ErrorType::RequestCanceled`]: crate::error::ErrorType::RequestCanceled
/// [`ErrorType::RequestError`]: crate::error::ErrorType::RequestError
/// [`ErrorType::RequestTimedOut`]: crate::error::ErrorType::RequestTimedOut
/// [`ErrorType::Response`]: crate::error::ErrorType::Response
/// [`Response`]: super::Response
// `Ok` holds the state of a live request. `Err` holds an error that is
// returned on the first poll without sending anything.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ResponseFuture<T>(Result<Inner<T>, Ready<Error>>);
188
189impl<T> ResponseFuture<T> {
190 pub(crate) fn new(
191 client: HyperClient<Connector, Full<Bytes>>,
192 invalid_token: Option<Arc<AtomicBool>>,
193 request: Request<Full<Bytes>>,
194 span: tracing::Span,
195 timeout: Duration,
196 rate_limiter: Option<RateLimiter>,
197 endpoint: Endpoint,
198 ) -> Self {
199 let permit_generator = rate_limiter.map(|rate_limiter| PermitFutureGenerator {
200 rate_limiter,
201 endpoint,
202 });
203 let response_generator = TimedResponseFutureGenerator {
204 client,
205 request,
206 timeout,
207 };
208 let stage = permit_generator.as_ref().map_or_else(
209 || ResponseStageFuture::Response {
210 fut: response_generator.generate(),
211 permit: None,
212 },
213 |generator| ResponseStageFuture::RateLimitPermit(generator.generate()),
214 );
215 Self(Ok(Inner {
216 invalid_token,
217 permit_generator,
218 phantom: PhantomData,
219 pre_flight_check: None,
220 response_generator,
221 span,
222 stage,
223 }))
224 }
225
226 /// Set a function to call after clearing the ratelimiter but prior to
227 /// sending the request to determine if the request is still valid.
228 ///
229 /// This function will be a no-op if the request has failed, has already
230 /// passed the ratelimiter, or if there is no ratelimiter configured.
231 ///
232 /// Returns whether the pre flight function was set.
233 ///
234 /// # Examples
235 ///
236 /// Delete a message, but immediately before sending the request check if
237 /// the request should still be sent:
238 ///
239 /// ```no_run
240 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
241 /// use std::{
242 /// collections::HashSet,
243 /// env,
244 /// future::IntoFuture,
245 /// sync::{Arc, Mutex},
246 /// };
247 /// use twilight_http::{Client, error::ErrorType};
248 /// use twilight_model::id::Id;
249 ///
250 /// let channel_id = Id::new(1);
251 /// let message_id = Id::new(2);
252 ///
253 /// let channels_ignored = {
254 /// let mut map = HashSet::new();
255 /// map.insert(channel_id);
256 ///
257 /// Arc::new(Mutex::new(map))
258 /// };
259 ///
260 /// let client = Client::new(env::var("DISCORD_TOKEN")?);
261 /// let mut req = client.delete_message(channel_id, message_id).into_future();
262 ///
263 /// let channels_ignored_clone = channels_ignored.clone();
264 /// req.set_pre_flight(move || {
265 /// // imagine you have some logic here to external state that checks
266 /// // whether the request should still be performed
267 /// let channels_ignored = channels_ignored_clone.lock().expect("channels poisoned");
268 ///
269 /// !channels_ignored.contains(&channel_id)
270 /// });
271 ///
272 /// // the pre-flight check will cancel the request
273 /// assert!(matches!(
274 /// req.await.unwrap_err().kind(),
275 /// ErrorType::RequestCanceled,
276 /// ));
277 /// # Ok(()) }
278 /// ```
279 pub fn set_pre_flight<P>(&mut self, predicate: P) -> bool
280 where
281 P: Fn() -> bool + Send + 'static,
282 {
283 if let Ok(inner) = &mut self.0
284 && inner.permit_generator.is_some()
285 && inner.pre_flight_check.is_none()
286 {
287 inner.pre_flight_check = Some(Box::new(predicate));
288
289 true
290 } else {
291 false
292 }
293 }
294
295 /// Creates a future that is immediately ready with an error.
296 pub(crate) fn error(source: Error) -> Self {
297 Self(Err(ready(source)))
298 }
299}
300
/// Drives the request through its stages: acquiring a rate limit permit (if a
/// rate limiter is configured), sending the request, and reading the body of
/// an error response.
impl<T: Unpin> Future for ResponseFuture<T> {
    type Output = Result<Response<T>, Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // A future created via `ResponseFuture::error` resolves to its stored
        // error right away.
        let inner = match &mut self.0 {
            Ok(inner) => inner,
            Err(err) => return Pin::new(err).poll(cx).map(Err),
        };

        // Keep the span entered for the duration of this poll call only.
        let _entered = inner.span.enter();

        // Loop so that a stage transition is polled immediately instead of
        // waiting for another wake-up.
        loop {
            match &mut inner.stage {
                ResponseStageFuture::Error { fut, status } => {
                    let body = ready!(Pin::new(fut).poll(cx)).map_err(|source| Error {
                        kind: ErrorType::RequestError,
                        source: Some(Box::new(source)),
                    })?;

                    // This stage always ends in an error: either the decoded
                    // API error or a parsing error carrying the raw body.
                    return Poll::Ready(Err(match crate::json::from_bytes::<ApiError>(&body) {
                        Ok(error) => Error {
                            kind: ErrorType::Response {
                                body,
                                error,
                                status: super::StatusCode::new(status.as_u16()),
                            },
                            source: None,
                        },
                        Err(source) => Error {
                            kind: ErrorType::Parsing { body },
                            source: Some(Box::new(source)),
                        },
                    }));
                }
                ResponseStageFuture::RateLimitPermit(fut) => {
                    let permit = ready!(Pin::new(fut).poll(cx));
                    // Run the user's pre-flight predicate, if any, now that
                    // the permit is available; `false` cancels the request.
                    if inner
                        .pre_flight_check
                        .as_ref()
                        .is_some_and(|check| !check())
                    {
                        return Poll::Ready(Err(Error {
                            kind: ErrorType::RequestCanceled,
                            source: None,
                        }));
                    }

                    inner.stage = ResponseStageFuture::Response {
                        fut: inner.response_generator.generate(),
                        permit: Some(permit),
                    };
                }
                ResponseStageFuture::Response { fut, permit } => {
                    // The outer `Result` is the timeout, the inner one is the
                    // outcome of the HTTP client itself.
                    let response = ready!(Pin::new(fut).poll(cx))
                        .map_err(|source| Error {
                            kind: ErrorType::RequestTimedOut,
                            source: Some(Box::new(source)),
                        })?
                        .map_err(|source| Error {
                            kind: ErrorType::RequestError,
                            source: Some(Box::new(source)),
                        })?;

                    // Record a 401 in the shared invalid-token flag, if there
                    // is one.
                    if response.status() == StatusCode::UNAUTHORIZED
                        && let Some(invalid) = &inner.invalid_token
                    {
                        invalid.store(true, Ordering::Relaxed);
                    }

                    // Complete the permit with the parsed rate limit headers,
                    // or with `None` if they could not be parsed.
                    if let Some(permit) = permit.take() {
                        match parse_ratelimit_headers(response.headers()) {
                            Ok(v) => permit.complete(v),
                            Err(source) => {
                                tracing::warn!("header parsing failed: {source}; {response:?}");

                                permit.complete(None);
                            }
                        }
                    }

                    if response.status().is_success() {
                        #[cfg(feature = "decompression")]
                        let mut response = response;
                        // Inaccurate since end-users can only access the decompressed body.
                        #[cfg(feature = "decompression")]
                        response.headers_mut().remove(header::CONTENT_LENGTH);

                        return Poll::Ready(Ok(Response::new(response)));
                    } else if response.status() == StatusCode::TOO_MANY_REQUESTS {
                        // Rate limited: retry. With a rate limiter, wait for a
                        // new permit first; without one, resend immediately.
                        inner.stage = match &inner.permit_generator {
                            Some(generator) => {
                                ResponseStageFuture::RateLimitPermit(generator.generate())
                            }
                            None => ResponseStageFuture::Response {
                                fut: inner.response_generator.generate(),
                                permit: None,
                            },
                        };
                    } else {
                        // Any other failure: read the body so it can be turned
                        // into an error in the `Error` stage.
                        inner.stage = ResponseStageFuture::Error {
                            status: response.status(),
                            fut: Response::<()>::new(response).bytes(),
                        };
                    }
                }
            }
        }
    }
}
410
/// Internal response future fields.
struct Inner<T> {
    /// Whether the client's token is invalidated.
    ///
    /// Set to `true` when a response comes back with `401 Unauthorized`.
    invalid_token: Option<Arc<AtomicBool>>,
    /// Optional [`PermitFuture`] generator, if registered.
    ///
    /// `None` when no rate limiter is configured.
    permit_generator: Option<PermitFutureGenerator>,
    /// Marker for the type parameter of the [`Response`] this future
    /// resolves to; no `T` is stored.
    phantom: PhantomData<T>,
    /// Predicate to check after completing [`ResponseStageFuture::RateLimitPermit`].
    ///
    /// The request is canceled if it returns `false`.
    pre_flight_check: Option<Box<dyn Fn() -> bool + Send + 'static>>,
    /// [`Timeout<HyperResponseFuture>`] generator.
    response_generator: TimedResponseFutureGenerator,
    /// This future's span.
    span: tracing::Span,
    /// This future's current stage.
    stage: ResponseStageFuture,
}