tarpc_lib/server/
filter.rs

1// Copyright 2018 Google LLC
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT.
6
7use crate::{
8    server::{self, Channel},
9    util::Compact,
10};
11use fnv::FnvHashMap;
12use futures::{
13    channel::mpsc,
14    future::AbortRegistration,
15    prelude::*,
16    ready,
17    stream::Fuse,
18    task::{Context, Poll},
19};
20use log::{debug, info, trace};
21use pin_utils::{unsafe_pinned, unsafe_unpinned};
22use std::sync::{Arc, Weak};
23use std::{
24    collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
25};
26
/// A single-threaded filter that drops channels based on per-key limits.
///
/// Every channel produced by `listener` is mapped to a key by `keymaker`; channels whose key
/// already has `channels_per_key` open channels are dropped instead of being yielded.
#[derive(Debug)]
pub struct ChannelFilter<S, K, F>
where
    K: Eq + Hash,
{
    /// Source of new channels, fused so it keeps returning `None` once exhausted.
    listener: Fuse<S>,
    /// Maximum number of channels admitted at the same time for one key.
    channels_per_key: u32,
    /// Receives the key of every `Tracker` that is dropped.
    dropped_keys: mpsc::UnboundedReceiver<K>,
    /// Sender cloned into every new `Tracker`. Holding it here keeps `dropped_keys` from ever
    /// terminating.
    dropped_keys_tx: mpsc::UnboundedSender<K>,
    /// Weak handle to the shared tracker of each key seen recently; the tracker's strong count
    /// is used as the number of open channels for that key.
    key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
    /// Computes the key of each new channel.
    keymaker: F,
}
40
/// A channel that is tracked by a ChannelFilter.
///
/// All `Stream`, `Sink`, and `Channel` calls are forwarded to the wrapped channel. While it is
/// alive, it holds a strong reference to its key's tracker, which is how the filter counts open
/// channels per key.
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
    /// The wrapped channel.
    inner: C,
    /// Shared tracker for this channel's key; dropping the last reference reports the key back
    /// to the filter.
    tracker: Arc<Tracker<K>>,
}
47
impl<C, K> TrackedChannel<C, K> {
    // Generates `inner(self: Pin<&mut Self>) -> Pin<&mut C>`, a structural pin projection used
    // by the `Stream`, `Sink`, and `Channel` impls to forward calls to the wrapped channel.
    //
    // SAFETY: structural pinning of `inner` is sound only while `TrackedChannel` has no `Drop`
    // impl that moves `inner`, no manual `Unpin` impl, and `inner` is never moved out of a pinned
    // `TrackedChannel`. None of these exist in this file — NOTE(review): re-check if such impls
    // are ever added.
    unsafe_pinned!(inner: C);
}
51
/// Shared bookkeeping for the open channels of one key.
///
/// Each admitted `TrackedChannel` holds an `Arc` to its key's tracker. When the last `Arc` is
/// dropped, the key is sent back to the `ChannelFilter` over `dropped_keys`.
#[derive(Debug)]
struct Tracker<K> {
    /// The key being tracked. Set to `Some` at construction and only taken in `drop`.
    key: Option<K>,
    /// Sender half of the owning filter's dropped-keys channel.
    dropped_keys: mpsc::UnboundedSender<K>,
}
57
58impl<K> Drop for Tracker<K> {
59    fn drop(&mut self) {
60        // Don't care if the listener is dropped.
61        let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
62    }
63}
64
65impl<C, K> Stream for TrackedChannel<C, K>
66where
67    C: Stream,
68{
69    type Item = <C as Stream>::Item;
70
71    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
72        self.channel().poll_next(cx)
73    }
74}
75
76impl<C, I, K> Sink<I> for TrackedChannel<C, K>
77where
78    C: Sink<I>,
79{
80    type Error = C::Error;
81
82    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
83        self.channel().poll_ready(cx)
84    }
85
86    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
87        self.channel().start_send(item)
88    }
89
90    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
91        self.channel().poll_flush(cx)
92    }
93
94    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
95        self.channel().poll_close(cx)
96    }
97}
98
99impl<C, K> AsRef<C> for TrackedChannel<C, K> {
100    fn as_ref(&self) -> &C {
101        &self.inner
102    }
103}
104
105impl<C, K> Channel for TrackedChannel<C, K>
106where
107    C: Channel,
108{
109    type Req = C::Req;
110    type Resp = C::Resp;
111
112    fn config(&self) -> &server::Config {
113        self.inner.config()
114    }
115
116    fn in_flight_requests(self: Pin<&mut Self>) -> usize {
117        self.inner().in_flight_requests()
118    }
119
120    fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
121        self.inner().start_request(request_id)
122    }
123}
124
125impl<C, K> TrackedChannel<C, K> {
126    /// Returns the inner channel.
127    pub fn get_ref(&self) -> &C {
128        &self.inner
129    }
130
131    /// Returns the pinned inner channel.
132    fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
133        self.inner()
134    }
135}
136
impl<S, K, F> ChannelFilter<S, K, F>
where
    K: fmt::Display + Eq + Hash + Clone,
{
    // Field accessors generated by `pin_utils`. `unsafe_pinned!` fields are structurally pinned
    // (the accessor takes `Pin<&mut Self>` and returns `Pin<&mut Field>`); `unsafe_unpinned!`
    // fields are never pinned (the accessor returns `&mut Field`).
    //
    // SAFETY: these projections require that `ChannelFilter` has no `Drop` impl that moves a
    // pinned field, no manual `Unpin` impl, and that `listener`, `dropped_keys`, and
    // `dropped_keys_tx` are never moved out of a pinned `ChannelFilter`. None of these appear in
    // this file — NOTE(review): confirm nothing elsewhere adds them.
    unsafe_pinned!(listener: Fuse<S>);
    unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
    unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
    unsafe_unpinned!(key_counts: FnvHashMap<K, Weak<Tracker<K>>>);
    unsafe_unpinned!(channels_per_key: u32);
    unsafe_unpinned!(keymaker: F);
}
148
149impl<S, K, F> ChannelFilter<S, K, F>
150where
151    K: Eq + Hash,
152    S: Stream,
153    F: Fn(&S::Item) -> K,
154{
155    /// Sheds new channels to stay under configured limits.
156    pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
157        let (dropped_keys_tx, dropped_keys) = mpsc::unbounded();
158        ChannelFilter {
159            listener: listener.fuse(),
160            channels_per_key,
161            dropped_keys,
162            dropped_keys_tx,
163            key_counts: FnvHashMap::default(),
164            keymaker,
165        }
166    }
167}
168
169impl<S, K, F> ChannelFilter<S, K, F>
170where
171    S: Stream,
172    K: fmt::Display + Eq + Hash + Clone + Unpin,
173    F: Fn(&S::Item) -> K,
174{
175    fn handle_new_channel(
176        mut self: Pin<&mut Self>,
177        stream: S::Item,
178    ) -> Result<TrackedChannel<S::Item, K>, K> {
179        let key = self.as_mut().keymaker()(&stream);
180        let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
181
182        trace!(
183            "[{}] Opening channel ({}/{}) channels for key.",
184            key,
185            Arc::strong_count(&tracker),
186            self.as_mut().channels_per_key()
187        );
188
189        Ok(TrackedChannel {
190            tracker,
191            inner: stream,
192        })
193    }
194
195    fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
196        let channels_per_key = self.channels_per_key;
197        let dropped_keys = self.dropped_keys_tx.clone();
198        let key_counts = &mut self.as_mut().key_counts();
199        match key_counts.entry(key.clone()) {
200            Entry::Vacant(vacant) => {
201                let tracker = Arc::new(Tracker {
202                    key: Some(key),
203                    dropped_keys,
204                });
205
206                vacant.insert(Arc::downgrade(&tracker));
207                Ok(tracker)
208            }
209            Entry::Occupied(mut o) => {
210                let count = o.get().strong_count();
211                if count >= channels_per_key.try_into().unwrap() {
212                    info!(
213                        "[{}] Opened max channels from key ({}/{}).",
214                        key, count, channels_per_key
215                    );
216                    Err(key)
217                } else {
218                    Ok(o.get().upgrade().unwrap_or_else(|| {
219                        let tracker = Arc::new(Tracker {
220                            key: Some(key),
221                            dropped_keys,
222                        });
223
224                        *o.get_mut() = Arc::downgrade(&tracker);
225                        tracker
226                    }))
227                }
228            }
229        }
230    }
231
232    fn poll_listener(
233        mut self: Pin<&mut Self>,
234        cx: &mut Context<'_>,
235    ) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
236        match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
237            Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
238            None => Poll::Ready(None),
239        }
240    }
241
242    fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
243        match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) {
244            Some(key) => {
245                debug!("All channels dropped for key [{}]", key);
246                self.as_mut().key_counts().remove(&key);
247                self.as_mut().key_counts().compact(0.1);
248                Poll::Ready(())
249            }
250            None => unreachable!("Holding a copy of closed_channels and didn't close it."),
251        }
252    }
253}
254
255impl<S, K, F> Stream for ChannelFilter<S, K, F>
256where
257    S: Stream,
258    K: fmt::Display + Eq + Hash + Clone + Unpin,
259    F: Fn(&S::Item) -> K,
260{
261    type Item = TrackedChannel<S::Item, K>;
262
263    fn poll_next(
264        mut self: Pin<&mut Self>,
265        cx: &mut Context<'_>,
266    ) -> Poll<Option<TrackedChannel<S::Item, K>>> {
267        loop {
268            match (
269                self.as_mut().poll_listener(cx),
270                self.as_mut().poll_closed_channels(cx),
271            ) {
272                (Poll::Ready(Some(Ok(channel))), _) => {
273                    return Poll::Ready(Some(channel));
274                }
275                (Poll::Ready(Some(Err(_))), _) => {
276                    continue;
277                }
278                (_, Poll::Ready(())) => continue,
279                (Poll::Pending, Poll::Pending) => return Poll::Pending,
280                (Poll::Ready(None), Poll::Pending) => {
281                    trace!("Shutting down listener.");
282                    return Poll::Ready(None);
283                }
284            }
285        }
286    }
287}
288
/// Returns a task context backed by a no-op waker, for polling by hand in unit tests.
#[cfg(test)]
fn ctx() -> Context<'static> {
    use futures_test::task::noop_waker_ref;

    // `noop_waker_ref` already returns a `&'static Waker`, so no extra borrow is needed.
    Context::from_waker(noop_waker_ref())
}
295
#[test]
fn tracker_drop() {
    use assert_matches::assert_matches;

    let (dropped_tx, mut dropped_rx) = mpsc::unbounded();
    // Dropping a tracker must report its key on the dropped-keys channel.
    drop(Tracker {
        key: Some(1),
        dropped_keys: dropped_tx,
    });
    assert_matches!(dropped_rx.try_next(), Ok(Some(1)));
}
307
#[test]
fn tracked_channel_stream() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    let (sender, receiver) = mpsc::unbounded();
    let (dropped_tx, _) = mpsc::unbounded();
    let tracker = Arc::new(Tracker {
        key: Some(1),
        dropped_keys: dropped_tx,
    });
    let channel = TrackedChannel {
        inner: receiver,
        tracker,
    };

    // Items sent to the wrapped receiver come out of the tracked channel unchanged.
    sender.unbounded_send("test").unwrap();
    pin_mut!(channel);
    assert_matches!(channel.poll_next(&mut ctx()), Poll::Ready(Some("test")));
}
327
#[test]
fn tracked_channel_sink() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    let (sender, mut receiver) = mpsc::unbounded();
    let (dropped_tx, _) = mpsc::unbounded();
    let channel = TrackedChannel {
        inner: sender,
        tracker: Arc::new(Tracker {
            key: Some(1),
            dropped_keys: dropped_tx,
        }),
    };
    pin_mut!(channel);

    // A full ready/send/flush cycle through the tracked channel reaches the wrapped sender.
    let mut cx = ctx();
    assert_matches!(channel.as_mut().poll_ready(&mut cx), Poll::Ready(Ok(())));
    assert_matches!(channel.as_mut().start_send("test"), Ok(()));
    assert_matches!(channel.as_mut().poll_flush(&mut cx), Poll::Ready(Ok(())));
    assert_matches!(receiver.try_next(), Ok(Some("test")));
}
349
#[test]
fn channel_filter_increment_channels_for_key() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    struct TestChannel {
        key: &'static str,
    }
    let (_, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    let first = filter.as_mut().increment_channels_for_key("key").unwrap();
    assert_eq!(Arc::strong_count(&first), 1);

    // The second slot for the same key shares the first tracker.
    let second = filter.as_mut().increment_channels_for_key("key").unwrap();
    assert_eq!(Arc::strong_count(&first), 2);

    // The limit of two is reached, so a third slot is refused.
    assert_matches!(filter.increment_channels_for_key("key"), Err("key"));

    drop(second);
    assert_eq!(Arc::strong_count(&first), 1);
}
369
#[test]
fn channel_filter_handle_new_channel() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (_, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    let first = filter
        .as_mut()
        .handle_new_channel(TestChannel { key: "key" })
        .unwrap();
    assert_eq!(Arc::strong_count(&first.tracker), 1);

    let second = filter
        .as_mut()
        .handle_new_channel(TestChannel { key: "key" })
        .unwrap();
    assert_eq!(Arc::strong_count(&first.tracker), 2);

    // A third channel for the same key exceeds the limit and is handed back as its key.
    let third = filter.handle_new_channel(TestChannel { key: "key" });
    assert_matches!(third, Err("key"));

    drop(second);
    assert_eq!(Arc::strong_count(&first.tracker), 1);
}
401
#[test]
fn channel_filter_poll_listener() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (sender, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    // First channel for "key": admitted with one open channel.
    sender.unbounded_send(TestChannel { key: "key" }).unwrap();
    let first =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
    assert_eq!(Arc::strong_count(&first.tracker), 1);

    // Second channel: admitted, sharing the same tracker.
    sender.unbounded_send(TestChannel { key: "key" }).unwrap();
    let _second =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
    assert_eq!(Arc::strong_count(&first.tracker), 2);

    // Third channel: over the limit, so it is rejected and the count is unchanged.
    sender.unbounded_send(TestChannel { key: "key" }).unwrap();
    let rejected =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
    assert_eq!(rejected, "key");
    assert_eq!(Arc::strong_count(&first.tracker), 2);
}
437
#[test]
fn channel_filter_poll_closed_channels() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (sender, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    sender.unbounded_send(TestChannel { key: "key" }).unwrap();
    let opened =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
    assert_eq!(filter.key_counts.len(), 1);

    // Closing the only channel for the key lets the filter forget the key.
    drop(opened);
    let closed = filter.as_mut().poll_closed_channels(&mut ctx());
    assert_matches!(closed, Poll::Ready(()));
    assert!(filter.key_counts.is_empty());
}
465
#[test]
fn channel_filter_stream() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    // `sender` must stay alive so the listener reports `Pending` rather than ending.
    let (sender, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    sender.unbounded_send(TestChannel { key: "key" }).unwrap();
    let opened =
        assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Ready(Some(c)) => c);
    assert_eq!(filter.key_counts.len(), 1);

    // After the channel is dropped, the next poll drains the drop notification and then has
    // nothing else to do.
    drop(opened);
    assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Pending);
    assert!(filter.key_counts.is_empty());
}