1use super::{Channel, Config};
2use crate::{Response, ServerError};
3use futures::{
4 future::AbortRegistration,
5 prelude::*,
6 ready,
7 task::{Context, Poll},
8};
9use log::debug;
10use pin_utils::{unsafe_pinned, unsafe_unpinned};
11use std::{io, pin::Pin};
12
13#[derive(Debug)]
16pub struct Throttler<C> {
17 max_in_flight_requests: usize,
18 inner: C,
19}
20
21impl<C> Throttler<C> {
22 unsafe_unpinned!(max_in_flight_requests: usize);
23 unsafe_pinned!(inner: C);
24
25 pub fn get_ref(&self) -> &C {
27 &self.inner
28 }
29}
30
31impl<C> Throttler<C>
32where
33 C: Channel,
34{
35 pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
38 Throttler {
39 inner,
40 max_in_flight_requests,
41 }
42 }
43}
44
45impl<C> Stream for Throttler<C>
46where
47 C: Channel,
48{
49 type Item = <C as Stream>::Item;
50
51 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
52 while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() {
53 ready!(self.as_mut().inner().poll_ready(cx)?);
54
55 match ready!(self.as_mut().inner().poll_next(cx)?) {
56 Some(request) => {
57 debug!(
58 "[{}] Client has reached in-flight request limit ({}/{}).",
59 request.context.trace_id(),
60 self.as_mut().in_flight_requests(),
61 self.as_mut().max_in_flight_requests(),
62 );
63
64 self.as_mut().start_send(Response {
65 request_id: request.id,
66 message: Err(ServerError {
67 kind: io::ErrorKind::WouldBlock,
68 detail: Some("Server throttled the request.".into()),
69 }),
70 })?;
71 }
72 None => return Poll::Ready(None),
73 }
74 }
75 self.inner().poll_next(cx)
76 }
77}
78
79impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
80where
81 C: Channel,
82{
83 type Error = io::Error;
84
85 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
86 self.inner().poll_ready(cx)
87 }
88
89 fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
90 self.inner().start_send(item)
91 }
92
93 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
94 self.inner().poll_flush(cx)
95 }
96
97 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
98 self.inner().poll_close(cx)
99 }
100}
101
102impl<C> AsRef<C> for Throttler<C> {
103 fn as_ref(&self) -> &C {
104 &self.inner
105 }
106}
107
108impl<C> Channel for Throttler<C>
109where
110 C: Channel,
111{
112 type Req = <C as Channel>::Req;
113 type Resp = <C as Channel>::Resp;
114
115 fn in_flight_requests(self: Pin<&mut Self>) -> usize {
116 self.inner().in_flight_requests()
117 }
118
119 fn config(&self) -> &Config {
120 self.inner.config()
121 }
122
123 fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
124 self.inner().start_request(request_id)
125 }
126}
127
128#[derive(Debug)]
130pub struct ThrottlerStream<S> {
131 inner: S,
132 max_in_flight_requests: usize,
133}
134
135impl<S> ThrottlerStream<S>
136where
137 S: Stream,
138 <S as Stream>::Item: Channel,
139{
140 unsafe_pinned!(inner: S);
141 unsafe_unpinned!(max_in_flight_requests: usize);
142
143 pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
144 Self {
145 inner,
146 max_in_flight_requests,
147 }
148 }
149}
150
151impl<S> Stream for ThrottlerStream<S>
152where
153 S: Stream,
154 <S as Stream>::Item: Channel,
155{
156 type Item = Throttler<<S as Stream>::Item>;
157
158 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
159 match ready!(self.as_mut().inner().poll_next(cx)) {
160 Some(channel) => Poll::Ready(Some(Throttler::new(
161 channel,
162 *self.max_in_flight_requests(),
163 ))),
164 None => Poll::Ready(None),
165 }
166 }
167}
168
169#[cfg(test)]
170use super::testing::{self, FakeChannel, PollExt};
171#[cfg(test)]
172use crate::Request;
173#[cfg(test)]
174use pin_utils::pin_mut;
175#[cfg(test)]
176use std::marker::PhantomData;
177
/// The throttler reports exactly the wrapped channel's in-flight count.
#[test]
fn throttler_in_flight_requests() {
    let throttler = Throttler::new(FakeChannel::default::<isize, isize>(), 0);
    pin_mut!(throttler);

    (0..5).for_each(|request_id| {
        throttler.inner.in_flight_requests.insert(request_id);
    });
    assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
191
/// Starting a request through the throttler records it in the wrapped channel.
#[test]
fn throttler_start_request() {
    let throttler = Throttler::new(FakeChannel::default::<isize, isize>(), 0);
    pin_mut!(throttler);

    let _ = throttler.as_mut().start_request(1);
    assert_eq!(throttler.inner.in_flight_requests.len(), 1);
}
203
/// A throttler over a channel with no queued requests reports end-of-stream.
#[test]
fn throttler_poll_next_done() {
    let throttler = Throttler::new(FakeChannel::default::<isize, isize>(), 0);
    pin_mut!(throttler);

    let polled = throttler.as_mut().poll_next(&mut testing::cx());
    assert!(polled.is_done());
}
214
/// Below the in-flight limit, queued requests pass through the throttler untouched.
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
    let throttler = Throttler::new(FakeChannel::default::<isize, isize>(), 1);
    pin_mut!(throttler);
    throttler.inner.push_req(0, 1);

    assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
    let polled = throttler.as_mut().poll_next(&mut testing::cx())?;
    let id_and_message = polled.map(|next| next.map(|req| (req.id, req.message)));
    assert_eq!(id_and_message, Poll::Ready(Some((0, 1))));
    Ok(())
}
234
/// At the limit, an incoming request is answered with an error response instead of yielded.
#[test]
fn throttler_poll_next_throttled() {
    let throttler = Throttler::new(FakeChannel::default::<isize, isize>(), 0);
    pin_mut!(throttler);
    throttler.inner.push_req(1, 1);

    let polled = throttler.as_mut().poll_next(&mut testing::cx());
    assert!(polled.is_done());

    let sent = &throttler.inner.sink;
    assert_eq!(sent.len(), 1);
    let rejection = sent.get(0).unwrap();
    assert_eq!(rejection.request_id, 1);
    assert!(rejection.message.is_err());
}
250
/// If the wrapped channel's sink never becomes ready, a saturated throttler stays pending
/// instead of reading a request it could not reject.
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
    /// Test channel whose sink is permanently pending. Its stream half must never be polled, so
    /// `poll_next` is left unimplemented.
    struct PendingSink<In, Out> {
        _marker: PhantomData<fn(Out) -> In>,
    }

    impl PendingSink<(), ()> {
        pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
            PendingSink {
                _marker: PhantomData,
            }
        }
    }

    impl<In, Out> Stream for PendingSink<In, Out> {
        type Item = In;

        fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
            unimplemented!()
        }
    }

    impl<In, Out> Sink<Out> for PendingSink<In, Out> {
        type Error = io::Error;

        fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
            Err(io::Error::from(io::ErrorKind::WouldBlock))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }
    }

    impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
        type Req = Req;
        type Resp = Resp;

        fn config(&self) -> &Config {
            unimplemented!()
        }

        fn in_flight_requests(self: Pin<&mut Self>) -> usize {
            0
        }

        fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
            unimplemented!()
        }
    }

    let throttler = Throttler::new(PendingSink::default::<isize, isize>(), 0);
    pin_mut!(throttler);
    assert!(throttler.poll_next(&mut testing::cx()).is_pending());
}
303
/// Sending a response through the throttler reaches the wrapped channel's sink and clears the
/// matching in-flight entry.
#[test]
fn throttler_start_send() {
    let throttler = Throttler::new(FakeChannel::default::<isize, isize>(), 0);
    pin_mut!(throttler);
    throttler.inner.in_flight_requests.insert(0);

    let response = Response {
        request_id: 0,
        message: Ok(1),
    };
    throttler.as_mut().start_send(response).unwrap();

    assert!(throttler.inner.in_flight_requests.is_empty());
    let expected = Response {
        request_id: 0,
        message: Ok(1),
    };
    assert_eq!(throttler.inner.sink.get(0), Some(&expected));
}