// teloxide_core/adaptors/throttle/request.rs
1use std::{
2 future::{Future, IntoFuture},
3 pin::Pin,
4 sync::Arc,
5 time::Instant,
6};
7
8use futures::{
9 future::BoxFuture,
10 task::{Context, Poll},
11};
12use tokio::sync::mpsc;
13
14use crate::{
15 adaptors::throttle::{channel, ChatIdHash, FreezeUntil, RequestLock},
16 errors::AsResponseParameters,
17 requests::{HasPayload, Output, Request},
18};
19
/// Request returned by [`Throttling`](crate::adaptors::Throttle) methods.
#[must_use = "Requests are lazy and do nothing unless sent"]
#[derive(Clone)]
pub struct ThrottlingRequest<R: HasPayload> {
    // The wrapped request. Kept behind `Arc` so `send_ref` can hand a shared
    // handle to the sending future without cloning the request itself.
    pub(super) request: Arc<R>,
    // Extracts the (hashed) target chat id from the payload; used as the key
    // sent to the throttle worker alongside the request lock.
    pub(super) chat_id: fn(&R::Payload) -> ChatIdHash,
    // Channel to the throttle worker: each send enqueues a `(chat, lock)`
    // pair, and the worker later unlocks the lock to allow sending.
    pub(super) worker: mpsc::Sender<(ChatIdHash, RequestLock)>,
}
28
/// Future returned by [`ThrottlingRequest`]s.
///
/// A thin wrapper over a boxed future, giving the `Request` impl concrete
/// `Send`/`SendRef` types to name.
#[pin_project::pin_project]
pub struct ThrottlingSend<R: Request>(#[pin] BoxFuture<'static, Result<Output<R>, R::Err>>);
32
/// A request that is either uniquely owned or shared behind an `Arc`.
///
/// `ThrottlingRequest::send` produces `Owned` when it can unwrap the `Arc`;
/// `send_ref` always produces `Shared`.
enum ShareableRequest<R> {
    Shared(Arc<R>),
    // Option is used to `take` ownership (move the request out) when sending
    // by value; it is only ever `None` immediately before returning.
    Owned(Option<R>),
}
38
39impl<R: HasPayload + Clone> HasPayload for ThrottlingRequest<R> {
40 type Payload = R::Payload;
41
42 /// Note that if this request was already executed via `send_ref` and it
43 /// didn't yet completed, this method will clone the underlying request.
44 fn payload_mut(&mut self) -> &mut Self::Payload {
45 Arc::make_mut(&mut self.request).payload_mut()
46 }
47
48 fn payload_ref(&self) -> &Self::Payload {
49 self.request.payload_ref()
50 }
51}
52
53impl<R> Request for ThrottlingRequest<R>
54where
55 R: Request + Clone + Send + Sync + 'static, // TODO: rem static
56 R::Err: AsResponseParameters + Send,
57 Output<R>: Send,
58{
59 type Err = R::Err;
60 type Send = ThrottlingSend<R>;
61 type SendRef = ThrottlingSend<R>;
62
63 fn send(self) -> Self::Send {
64 let chat = (self.chat_id)(self.payload_ref());
65 let request = match Arc::try_unwrap(self.request) {
66 Ok(owned) => ShareableRequest::Owned(Some(owned)),
67 Err(shared) => ShareableRequest::Shared(shared),
68 };
69 let fut = send(request, chat, self.worker);
70
71 ThrottlingSend(Box::pin(fut))
72 }
73
74 fn send_ref(&self) -> Self::SendRef {
75 let chat = (self.chat_id)(self.payload_ref());
76 let request = ShareableRequest::Shared(Arc::clone(&self.request));
77 let fut = send(request, chat, self.worker.clone());
78
79 ThrottlingSend(Box::pin(fut))
80 }
81}
82
impl<R> IntoFuture for ThrottlingRequest<R>
where
    R: Request + Clone + Send + Sync + 'static,
    R::Err: AsResponseParameters + Send,
    Output<R>: Send,
{
    type Output = Result<Output<Self>, <Self as Request>::Err>;
    type IntoFuture = <Self as Request>::Send;

    /// Allows a `ThrottlingRequest` to be `.await`ed directly by delegating
    /// to [`Request::send`].
    fn into_future(self) -> Self::IntoFuture {
        self.send()
    }
}
96
97impl<R: Request> Future for ThrottlingSend<R>
98where
99 R::Err: AsResponseParameters,
100{
101 type Output = Result<Output<R>, R::Err>;
102
103 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
104 self.as_mut().project().0.poll(cx)
105 }
106}
107
108// This diagram explains how `ThrottlingRequest` works/what `send` does
109//
110// │
111// ThrottlingRequest │ worker()
112// │
113// ┌───────────────┐ │ ┌────────────────────────┐
114// ┌──────────────────►│request is sent│ │ │see worker documentation│
115// │ └───────┬───────┘ │ │and comments for more │
116// │ │ │ │information on how it │
117// │ ▼ │ │actually works │
118// │ ┌─────────┐ │ └────────────────────────┘
119// │ ┌────────────────┐ │send lock│ │
120// │ │has worker died?│◄──┤to worker├─────►:───────────┐
121// │ └─┬─────────────┬┘ └─────────┘ │ ▼
122// │ │ │ │ ┌──────────────────┐
123// │ Y └─N───────┐ │ │ *magic* │
124// │ │ │ │ └────────┬─────────┘
125// │ ▼ ▼ │ │
126// │ ┌───────────┐ ┌────────────────┐ │ ▼
127// │ │send inner │ │wait for worker │ │ ┌─────────────────┐
128// │ │request │ │to allow sending│◄──:◄─┤ `lock.unlock()` │
129// │ └───┬───────┘ │this request │ │ └─────────────────┘
130// │ │ └────────┬───────┘ │
131// │ │ │ │
132// │ ▼ ▼ │
133// │ ┌──────┐ ┌────────────────────┐ │
134// │ │return│ │send inner request │ │
135// │ │result│ │and check its result│ │
136// │ └──────┘ └─┬─────────┬────────┘ │
137// │ ▲ ▲ │ │ │
138// │ │ │ │ Err(RetryAfter(n)) │
139// │ │ │ else │ │
140// │ │ │ │ ▼ │
141// │ │ └─────┘ ┌───────────────┐ │
142// │ │ │are retries on?│ │
143// │ │ └┬─────────────┬┘ │
144// │ │ │ │ │
145// │ └────────────N─┘ Y │
146// │ │ │ ┌──────────────────┐
147// │ ▼ │ │ *magic* │
148// │ ┌──────────────────┐ │ └──────────────────┘
149// ┌┴────────────┐ │notify worker that│ │ ▲
150// │retry request│◄──┤RetryAfter error ├──►:───────────┘
151// └─────────────┘ │has happened │ │
152// └──────────────────┘ │
153// │
154
/// Actual implementation of the `ThrottlingSend` future.
///
/// Registers the request with the throttle worker, waits for permission to
/// send, sends the inner request, and — when retries are enabled — reports
/// `RetryAfter` errors back to the worker and retries after the freeze.
async fn send<R>(
    mut request: ShareableRequest<R>,
    chat: ChatIdHash,
    worker: mpsc::Sender<(ChatIdHash, RequestLock)>,
) -> Result<Output<R>, R::Err>
where
    R: Request + Send + Sync + 'static,
    R::Err: AsResponseParameters + Send,
    Output<R>: Send,
{
    // We use option in `ShareableRequest` to `take` when sending by value.
    //
    // All unwraps down below will succeed because we always return immediately
    // after taking.

    loop {
        // A fresh lock/wait pair per attempt; `wait` resolves once the worker
        // allows this request to go through.
        let (lock, wait) = channel();

        // The worker is unlikely to drop queue before sending all requests,
        // but just in case it has dropped the queue, we want to just send the
        // request.
        if worker.send((chat, lock)).await.is_err() {
            log::error!("Worker dropped the queue before sending all requests");

            // Worker is gone — bypass throttling and send directly.
            let res = match &mut request {
                ShareableRequest::Shared(shared) => shared.send_ref().await,
                ShareableRequest::Owned(owned) => owned.take().unwrap().await,
            };

            return res;
        };

        // `retry` — whether the worker has retries enabled for this request;
        // `freeze` — back-channel for reporting `RetryAfter` to the worker.
        let (retry, freeze) = wait.await;

        let res = match (retry, &mut request) {
            // Retries are turned on, use `send_ref` even if we have owned access
            (true, request) => {
                let request = match request {
                    ShareableRequest::Shared(shared) => &**shared,
                    ShareableRequest::Owned(owned) => owned.as_ref().unwrap(),
                };

                request.send_ref().await
            }
            (false, ShareableRequest::Shared(shared)) => shared.send_ref().await,
            // No retries: this is the only attempt, so the owned request can
            // be consumed.
            (false, ShareableRequest::Owned(owned)) => owned.take().unwrap().await,
        };

        let retry_after = res.as_ref().err().and_then(<_>::retry_after);
        if let Some(retry_after) = retry_after {
            let after = retry_after.duration();
            let until = Instant::now() + after;

            // If we'll retry, we check that worker hasn't died at the start of the loop
            // otherwise we don't care if the worker is alive or not
            let _ = freeze.send(FreezeUntil { until, after, chat }).await;

            if retry {
                log::warn!("Freezing, before retrying: {retry_after:?}");
                tokio::time::sleep_until(until.into()).await;
            }
        }

        match res {
            // Loop around only for `RetryAfter` errors with retries enabled;
            // every other outcome (success or other error) is final.
            Err(_) if retry && retry_after.is_some() => continue,
            res => break res,
        };
    }
}
224}