teloxide_core/adaptors/throttle/request.rs

1use std::{
2    future::{Future, IntoFuture},
3    pin::Pin,
4    sync::Arc,
5    time::Instant,
6};
7
8use futures::{
9    future::BoxFuture,
10    task::{Context, Poll},
11};
12use tokio::sync::mpsc;
13
14use crate::{
15    adaptors::throttle::{channel, ChatIdHash, FreezeUntil, RequestLock},
16    errors::AsResponseParameters,
17    requests::{HasPayload, Output, Request},
18};
19
20/// Request returned by [`Throttling`](crate::adaptors::Throttle) methods.
21#[must_use = "Requests are lazy and do nothing unless sent"]
22#[derive(Clone)]
23pub struct ThrottlingRequest<R: HasPayload> {
24    pub(super) request: Arc<R>,
25    pub(super) chat_id: fn(&R::Payload) -> ChatIdHash,
26    pub(super) worker: mpsc::Sender<(ChatIdHash, RequestLock)>,
27}
28
29/// Future returned by [`ThrottlingRequest`]s.
30#[pin_project::pin_project]
31pub struct ThrottlingSend<R: Request>(#[pin] BoxFuture<'static, Result<Output<R>, R::Err>>);
32
/// How the `send` future holds the request it is going to send.
enum ShareableRequest<R> {
    /// The request is reference-counted and may still be referenced
    /// elsewhere, so it can only ever be sent by reference.
    Shared(Arc<R>),
    /// The request is exclusively owned. It sits in an `Option` so that it can
    /// be moved out with `Option::take` when it is finally sent by value.
    Owned(Option<R>),
}
38
39impl<R: HasPayload + Clone> HasPayload for ThrottlingRequest<R> {
40    type Payload = R::Payload;
41
42    /// Note that if this request was already executed via `send_ref` and it
43    /// didn't yet completed, this method will clone the underlying request.
44    fn payload_mut(&mut self) -> &mut Self::Payload {
45        Arc::make_mut(&mut self.request).payload_mut()
46    }
47
48    fn payload_ref(&self) -> &Self::Payload {
49        self.request.payload_ref()
50    }
51}
52
53impl<R> Request for ThrottlingRequest<R>
54where
55    R: Request + Clone + Send + Sync + 'static, // TODO: rem static
56    R::Err: AsResponseParameters + Send,
57    Output<R>: Send,
58{
59    type Err = R::Err;
60    type Send = ThrottlingSend<R>;
61    type SendRef = ThrottlingSend<R>;
62
63    fn send(self) -> Self::Send {
64        let chat = (self.chat_id)(self.payload_ref());
65        let request = match Arc::try_unwrap(self.request) {
66            Ok(owned) => ShareableRequest::Owned(Some(owned)),
67            Err(shared) => ShareableRequest::Shared(shared),
68        };
69        let fut = send(request, chat, self.worker);
70
71        ThrottlingSend(Box::pin(fut))
72    }
73
74    fn send_ref(&self) -> Self::SendRef {
75        let chat = (self.chat_id)(self.payload_ref());
76        let request = ShareableRequest::Shared(Arc::clone(&self.request));
77        let fut = send(request, chat, self.worker.clone());
78
79        ThrottlingSend(Box::pin(fut))
80    }
81}
82
83impl<R> IntoFuture for ThrottlingRequest<R>
84where
85    R: Request + Clone + Send + Sync + 'static,
86    R::Err: AsResponseParameters + Send,
87    Output<R>: Send,
88{
89    type Output = Result<Output<Self>, <Self as Request>::Err>;
90    type IntoFuture = <Self as Request>::Send;
91
92    fn into_future(self) -> Self::IntoFuture {
93        self.send()
94    }
95}
96
97impl<R: Request> Future for ThrottlingSend<R>
98where
99    R::Err: AsResponseParameters,
100{
101    type Output = Result<Output<R>, R::Err>;
102
103    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
104        self.as_mut().project().0.poll(cx)
105    }
106}
107
108// This diagram explains how `ThrottlingRequest` works/what `send` does
109//
110//                                          │
111//                      ThrottlingRequest   │   worker()
112//                                          │
113//                      ┌───────────────┐   │  ┌────────────────────────┐
114//  ┌──────────────────►│request is sent│   │  │see worker documentation│
115//  │                   └───────┬───────┘   │  │and comments for more   │
116//  │                           │           │  │information on how it   │
117//  │                           ▼           │  │actually works          │
118//  │                      ┌─────────┐      │  └────────────────────────┘
119//  │ ┌────────────────┐   │send lock│      │
120//  │ │has worker died?│◄──┤to worker├─────►:───────────┐
121//  │ └─┬─────────────┬┘   └─────────┘      │           ▼
122//  │   │             │                     │  ┌──────────────────┐
123//  │   Y             └─N───────┐           │  │     *magic*      │
124//  │   │                       │           │  └────────┬─────────┘
125//  │   ▼                       ▼           │           │
126//  │ ┌───────────┐    ┌────────────────┐   │           ▼
127//  │ │send inner │    │wait for worker │   │  ┌─────────────────┐
128//  │ │request    │    │to allow sending│◄──:◄─┤ `lock.unlock()` │
129//  │ └───┬───────┘    │this request    │   │  └─────────────────┘
130//  │     │            └────────┬───────┘   │
131//  │     │                     │           │
132//  │     ▼                     ▼           │
133//  │    ┌──────┐  ┌────────────────────┐   │
134//  │    │return│  │send inner request  │   │
135//  │    │result│  │and check its result│   │
136//  │    └──────┘  └─┬─────────┬────────┘   │
137//  │     ▲    ▲     │         │            │
138//  │     │    │     │ Err(RetryAfter(n))   │
139//  │     │    │   else        │            │
140//  │     │    │     │         ▼            │
141//  │     │    └─────┘  ┌───────────────┐   │
142//  │     │             │are retries on?│   │
143//  │     │             └┬─────────────┬┘   │
144//  │     │              │             │    │
145//  │     └────────────N─┘             Y    │
146//  │                                  │    │  ┌──────────────────┐
147//  │                                  ▼    │  │     *magic*      │
148//  │                ┌──────────────────┐   │  └──────────────────┘
149// ┌┴────────────┐   │notify worker that│   │           ▲
150// │retry request│◄──┤RetryAfter error  ├──►:───────────┘
151// └─────────────┘   │has happened      │   │
152//                   └──────────────────┘   │
153//                                          │
154
155/// Actual implementation of the `ThrottlingSend` future
156async fn send<R>(
157    mut request: ShareableRequest<R>,
158    chat: ChatIdHash,
159    worker: mpsc::Sender<(ChatIdHash, RequestLock)>,
160) -> Result<Output<R>, R::Err>
161where
162    R: Request + Send + Sync + 'static,
163    R::Err: AsResponseParameters + Send,
164    Output<R>: Send,
165{
166    // We use option in `ShareableRequest` to `take` when sending by value.
167    //
168    // All unwraps down below will succeed because we always return immediately
169    // after taking.
170
171    loop {
172        let (lock, wait) = channel();
173
174        // The worker is unlikely to drop queue before sending all requests,
175        // but just in case it has dropped the queue, we want to just send the
176        // request.
177        if worker.send((chat, lock)).await.is_err() {
178            log::error!("Worker dropped the queue before sending all requests");
179
180            let res = match &mut request {
181                ShareableRequest::Shared(shared) => shared.send_ref().await,
182                ShareableRequest::Owned(owned) => owned.take().unwrap().await,
183            };
184
185            return res;
186        };
187
188        let (retry, freeze) = wait.await;
189
190        let res = match (retry, &mut request) {
191            // Retries are turned on, use `send_ref` even if we have owned access
192            (true, request) => {
193                let request = match request {
194                    ShareableRequest::Shared(shared) => &**shared,
195                    ShareableRequest::Owned(owned) => owned.as_ref().unwrap(),
196                };
197
198                request.send_ref().await
199            }
200            (false, ShareableRequest::Shared(shared)) => shared.send_ref().await,
201            (false, ShareableRequest::Owned(owned)) => owned.take().unwrap().await,
202        };
203
204        let retry_after = res.as_ref().err().and_then(<_>::retry_after);
205        if let Some(retry_after) = retry_after {
206            let after = retry_after.duration();
207            let until = Instant::now() + after;
208
209            // If we'll retry, we check that worker hasn't died at the start of the loop
210            // otherwise we don't care if the worker is alive or not
211            let _ = freeze.send(FreezeUntil { until, after, chat }).await;
212
213            if retry {
214                log::warn!("Freezing, before retrying: {retry_after:?}");
215                tokio::time::sleep_until(until.into()).await;
216            }
217        }
218
219        match res {
220            Err(_) if retry && retry_after.is_some() => continue,
221            res => break res,
222        };
223    }
224}