tarpc_lib/server/
throttle.rs

1use super::{Channel, Config};
2use crate::{Response, ServerError};
3use futures::{
4    future::AbortRegistration,
5    prelude::*,
6    ready,
7    task::{Context, Poll},
8};
9use log::debug;
10use pin_utils::{unsafe_pinned, unsafe_unpinned};
11use std::{io, pin::Pin};
12
/// A [`Channel`] wrapper that caps how many requests may be in flight at once.
///
/// When the cap is reached, newly arriving requests are answered right away with
/// a throttling error response instead of being yielded from the stream.
#[derive(Debug)]
pub struct Throttler<C> {
    /// Upper bound on the number of concurrently in-flight requests.
    max_in_flight_requests: usize,
    /// The wrapped channel.
    inner: C,
}
20
21impl<C> Throttler<C> {
22    unsafe_unpinned!(max_in_flight_requests: usize);
23    unsafe_pinned!(inner: C);
24
25    /// Returns the inner channel.
26    pub fn get_ref(&self) -> &C {
27        &self.inner
28    }
29}
30
31impl<C> Throttler<C>
32where
33    C: Channel,
34{
35    /// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
36    /// `max_in_flight_requests`.
37    pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
38        Throttler {
39            inner,
40            max_in_flight_requests,
41        }
42    }
43}
44
45impl<C> Stream for Throttler<C>
46where
47    C: Channel,
48{
49    type Item = <C as Stream>::Item;
50
51    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
52        while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() {
53            ready!(self.as_mut().inner().poll_ready(cx)?);
54
55            match ready!(self.as_mut().inner().poll_next(cx)?) {
56                Some(request) => {
57                    debug!(
58                        "[{}] Client has reached in-flight request limit ({}/{}).",
59                        request.context.trace_id(),
60                        self.as_mut().in_flight_requests(),
61                        self.as_mut().max_in_flight_requests(),
62                    );
63
64                    self.as_mut().start_send(Response {
65                        request_id: request.id,
66                        message: Err(ServerError {
67                            kind: io::ErrorKind::WouldBlock,
68                            detail: Some("Server throttled the request.".into()),
69                        }),
70                    })?;
71                }
72                None => return Poll::Ready(None),
73            }
74        }
75        self.inner().poll_next(cx)
76    }
77}
78
79impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
80where
81    C: Channel,
82{
83    type Error = io::Error;
84
85    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
86        self.inner().poll_ready(cx)
87    }
88
89    fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
90        self.inner().start_send(item)
91    }
92
93    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
94        self.inner().poll_flush(cx)
95    }
96
97    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
98        self.inner().poll_close(cx)
99    }
100}
101
102impl<C> AsRef<C> for Throttler<C> {
103    fn as_ref(&self) -> &C {
104        &self.inner
105    }
106}
107
108impl<C> Channel for Throttler<C>
109where
110    C: Channel,
111{
112    type Req = <C as Channel>::Req;
113    type Resp = <C as Channel>::Resp;
114
115    fn in_flight_requests(self: Pin<&mut Self>) -> usize {
116        self.inner().in_flight_requests()
117    }
118
119    fn config(&self) -> &Config {
120        self.inner.config()
121    }
122
123    fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
124        self.inner().start_request(request_id)
125    }
126}
127
/// A stream adapter that wraps each [`Channel`] produced by an inner stream in a
/// [`Throttler`], so each channel gets the same in-flight request limit.
#[derive(Debug)]
pub struct ThrottlerStream<S> {
    /// Stream of channels to be throttled.
    inner: S,
    /// Limit applied to every channel this stream yields.
    max_in_flight_requests: usize,
}
134
135impl<S> ThrottlerStream<S>
136where
137    S: Stream,
138    <S as Stream>::Item: Channel,
139{
140    unsafe_pinned!(inner: S);
141    unsafe_unpinned!(max_in_flight_requests: usize);
142
143    pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
144        Self {
145            inner,
146            max_in_flight_requests,
147        }
148    }
149}
150
151impl<S> Stream for ThrottlerStream<S>
152where
153    S: Stream,
154    <S as Stream>::Item: Channel,
155{
156    type Item = Throttler<<S as Stream>::Item>;
157
158    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
159        match ready!(self.as_mut().inner().poll_next(cx)) {
160            Some(channel) => Poll::Ready(Some(Throttler::new(
161                channel,
162                *self.max_in_flight_requests(),
163            ))),
164            None => Poll::Ready(None),
165        }
166    }
167}
168
169#[cfg(test)]
170use super::testing::{self, FakeChannel, PollExt};
171#[cfg(test)]
172use crate::Request;
173#[cfg(test)]
174use pin_utils::pin_mut;
175#[cfg(test)]
176use std::marker::PhantomData;
177
#[test]
fn throttler_in_flight_requests() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };
    pin_mut!(throttler);

    // Register five outstanding request ids directly on the fake channel.
    for request_id in 0..5 {
        throttler.inner.in_flight_requests.insert(request_id);
    }

    // The throttler reports the inner channel's count unchanged.
    let reported = throttler.as_mut().in_flight_requests();
    assert_eq!(reported, 5);
}
191
#[test]
fn throttler_start_request() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };
    pin_mut!(throttler);

    // Starting a request through the throttler must register it on the inner channel.
    throttler.as_mut().start_request(1);
    let registered = throttler.inner.in_flight_requests.len();
    assert_eq!(registered, 1);
}
203
#[test]
fn throttler_poll_next_done() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };
    pin_mut!(throttler);

    // Nothing is queued, so the inner stream is exhausted and so is the throttler.
    let mut cx = testing::cx();
    let polled = throttler.as_mut().poll_next(&mut cx);
    assert!(polled.is_done());
}
214
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
    let throttler = Throttler {
        max_in_flight_requests: 1,
        inner: FakeChannel::default::<isize, isize>(),
    };
    pin_mut!(throttler);

    // No requests are in flight yet, so the queued request passes straight through.
    throttler.inner.push_req(0, 1);
    assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());

    let next = throttler.as_mut().poll_next(&mut testing::cx())?;
    let summary = next.map(|request| request.map(|request| (request.id, request.message)));
    assert_eq!(summary, Poll::Ready(Some((0, 1))));
    Ok(())
}
234
#[test]
fn throttler_poll_next_throttled() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };
    pin_mut!(throttler);

    // With a limit of zero, the queued request is rejected rather than yielded.
    throttler.inner.push_req(1, 1);
    assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());

    // The rejection is written to the sink as an error response for request 1.
    let sink = &throttler.inner.sink;
    assert_eq!(sink.len(), 1);
    let rejection = sink.get(0).unwrap();
    assert_eq!(rejection.request_id, 1);
    assert!(rejection.message.is_err());
}
250
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
    // A fake channel whose sink is never ready. It checks that the throttler waits
    // for sink readiness before pulling a request it would have to reject.
    struct PendingSink<In, Out> {
        marker: PhantomData<fn(Out) -> In>,
    }

    impl PendingSink<(), ()> {
        pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
            PendingSink {
                marker: PhantomData,
            }
        }
    }

    impl<In, Out> Stream for PendingSink<In, Out> {
        type Item = In;
        // Must never be reached: the throttler should stop at `poll_ready`.
        fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
            unimplemented!()
        }
    }

    impl<In, Out> Sink<Out> for PendingSink<In, Out> {
        type Error = io::Error;
        fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }
        fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
            Err(io::Error::from(io::ErrorKind::WouldBlock))
        }
        fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }
        fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }
    }

    impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
        type Req = Req;
        type Resp = Resp;
        fn config(&self) -> &Config {
            unimplemented!()
        }
        fn in_flight_requests(self: Pin<&mut Self>) -> usize {
            0
        }
        fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
            unimplemented!()
        }
    }

    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: PendingSink::default::<isize, isize>(),
    };
    pin_mut!(throttler);
    assert!(throttler.poll_next(&mut testing::cx()).is_pending());
}
303
#[test]
fn throttler_start_send() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };
    pin_mut!(throttler);

    // Sending a response for an in-flight request goes straight to the inner sink.
    throttler.inner.in_flight_requests.insert(0);
    throttler
        .as_mut()
        .start_send(Response {
            request_id: 0,
            message: Ok(1),
        })
        .unwrap();

    assert!(throttler.inner.in_flight_requests.is_empty());
    let expected = Response {
        request_id: 0,
        message: Ok(1),
    };
    assert_eq!(throttler.inner.sink.get(0), Some(&expected));
}