1use crate::{
8 server::{self, Channel},
9 util::Compact,
10};
11use fnv::FnvHashMap;
12use futures::{
13 channel::mpsc,
14 future::AbortRegistration,
15 prelude::*,
16 ready,
17 stream::Fuse,
18 task::{Context, Poll},
19};
20use log::{debug, info, trace};
21use pin_utils::{unsafe_pinned, unsafe_unpinned};
22use std::sync::{Arc, Weak};
23use std::{
24 collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
25};
26
/// A stream of channels that caps how many channels may be open at the same time per key.
///
/// Each channel produced by `listener` is mapped to a key by `keymaker`. If that key already
/// has `channels_per_key` open channels, the new channel is dropped; otherwise it is yielded
/// wrapped in a [`TrackedChannel`].
#[derive(Debug)]
pub struct ChannelFilter<S, K, F>
where
    K: Eq + Hash,
{
    /// Source of incoming channels; fused so it can be polled again after it has finished.
    listener: Fuse<S>,
    /// Maximum number of simultaneously open channels allowed for a single key.
    channels_per_key: u32,
    /// Receives a key each time the last `Tracker` for that key is dropped.
    dropped_keys: mpsc::UnboundedReceiver<K>,
    /// Sender cloned into every new `Tracker`. Because the filter itself holds this sender,
    /// `dropped_keys` never reaches end-of-stream.
    dropped_keys_tx: mpsc::UnboundedSender<K>,
    /// Per-key weak handle to the key's shared `Tracker`; the tracker's strong count is the
    /// number of currently open channels for that key.
    key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
    /// Computes the key of each incoming channel.
    keymaker: F,
}
40
/// A channel admitted by [`ChannelFilter`].
///
/// All `Stream`, `Sink`, and `Channel` calls are forwarded to `inner`. The channel also keeps
/// a strong reference to its key's tracker for as long as it exists, which is how the filter
/// counts open channels per key.
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
    /// The wrapped channel.
    inner: C,
    /// Shared per-key tracker; dropping the last one notifies the filter.
    tracker: Arc<Tracker<K>>,
}
47
impl<C, K> TrackedChannel<C, K> {
    // Pin projection `Pin<&mut TrackedChannel<C, K>>` -> `Pin<&mut C>`, used to forward
    // pinned calls to the wrapped channel.
    // NOTE(review): soundness relies on the usual pin-projection rules for `inner` (no `Drop`
    // impl that moves it, no manual `Unpin` impl); none is visible in this file — confirm.
    unsafe_pinned!(inner: C);
}
51
/// Per-key token shared (through `Arc`) by every open channel with that key.
///
/// When the last strong reference goes away, `Drop` sends `key` on `dropped_keys` so the
/// filter can forget about the key.
#[derive(Debug)]
struct Tracker<K> {
    /// Always `Some` until `Drop` takes it out to send it.
    key: Option<K>,
    /// Clone of the filter's `dropped_keys_tx` sender.
    dropped_keys: mpsc::UnboundedSender<K>,
}
57
58impl<K> Drop for Tracker<K> {
59 fn drop(&mut self) {
60 let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
62 }
63}
64
65impl<C, K> Stream for TrackedChannel<C, K>
66where
67 C: Stream,
68{
69 type Item = <C as Stream>::Item;
70
71 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
72 self.channel().poll_next(cx)
73 }
74}
75
76impl<C, I, K> Sink<I> for TrackedChannel<C, K>
77where
78 C: Sink<I>,
79{
80 type Error = C::Error;
81
82 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
83 self.channel().poll_ready(cx)
84 }
85
86 fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
87 self.channel().start_send(item)
88 }
89
90 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
91 self.channel().poll_flush(cx)
92 }
93
94 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
95 self.channel().poll_close(cx)
96 }
97}
98
99impl<C, K> AsRef<C> for TrackedChannel<C, K> {
100 fn as_ref(&self) -> &C {
101 &self.inner
102 }
103}
104
105impl<C, K> Channel for TrackedChannel<C, K>
106where
107 C: Channel,
108{
109 type Req = C::Req;
110 type Resp = C::Resp;
111
112 fn config(&self) -> &server::Config {
113 self.inner.config()
114 }
115
116 fn in_flight_requests(self: Pin<&mut Self>) -> usize {
117 self.inner().in_flight_requests()
118 }
119
120 fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
121 self.inner().start_request(request_id)
122 }
123}
124
125impl<C, K> TrackedChannel<C, K> {
126 pub fn get_ref(&self) -> &C {
128 &self.inner
129 }
130
131 fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
133 self.inner()
134 }
135}
136
impl<S, K, F> ChannelFilter<S, K, F>
where
    K: fmt::Display + Eq + Hash + Clone,
{
    // Pin projections for the filter's fields. `listener`, `dropped_keys`, and
    // `dropped_keys_tx` are projected as `Pin<&mut _>`; the remaining fields are handed out
    // as plain `&mut _` and are never treated as pinned.
    // NOTE(review): soundness assumes `ChannelFilter` has no `Drop` impl that moves the
    // pinned fields and no manual `Unpin` impl; none is visible in this file — confirm.
    unsafe_pinned!(listener: Fuse<S>);
    unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
    // NOTE(review): `dropped_keys_tx()` is not called in the visible code; the field is read
    // directly in `increment_channels_for_key`.
    unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
    unsafe_unpinned!(key_counts: FnvHashMap<K, Weak<Tracker<K>>>);
    unsafe_unpinned!(channels_per_key: u32);
    unsafe_unpinned!(keymaker: F);
}
148
149impl<S, K, F> ChannelFilter<S, K, F>
150where
151 K: Eq + Hash,
152 S: Stream,
153 F: Fn(&S::Item) -> K,
154{
155 pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
157 let (dropped_keys_tx, dropped_keys) = mpsc::unbounded();
158 ChannelFilter {
159 listener: listener.fuse(),
160 channels_per_key,
161 dropped_keys,
162 dropped_keys_tx,
163 key_counts: FnvHashMap::default(),
164 keymaker,
165 }
166 }
167}
168
169impl<S, K, F> ChannelFilter<S, K, F>
170where
171 S: Stream,
172 K: fmt::Display + Eq + Hash + Clone + Unpin,
173 F: Fn(&S::Item) -> K,
174{
175 fn handle_new_channel(
176 mut self: Pin<&mut Self>,
177 stream: S::Item,
178 ) -> Result<TrackedChannel<S::Item, K>, K> {
179 let key = self.as_mut().keymaker()(&stream);
180 let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
181
182 trace!(
183 "[{}] Opening channel ({}/{}) channels for key.",
184 key,
185 Arc::strong_count(&tracker),
186 self.as_mut().channels_per_key()
187 );
188
189 Ok(TrackedChannel {
190 tracker,
191 inner: stream,
192 })
193 }
194
195 fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
196 let channels_per_key = self.channels_per_key;
197 let dropped_keys = self.dropped_keys_tx.clone();
198 let key_counts = &mut self.as_mut().key_counts();
199 match key_counts.entry(key.clone()) {
200 Entry::Vacant(vacant) => {
201 let tracker = Arc::new(Tracker {
202 key: Some(key),
203 dropped_keys,
204 });
205
206 vacant.insert(Arc::downgrade(&tracker));
207 Ok(tracker)
208 }
209 Entry::Occupied(mut o) => {
210 let count = o.get().strong_count();
211 if count >= channels_per_key.try_into().unwrap() {
212 info!(
213 "[{}] Opened max channels from key ({}/{}).",
214 key, count, channels_per_key
215 );
216 Err(key)
217 } else {
218 Ok(o.get().upgrade().unwrap_or_else(|| {
219 let tracker = Arc::new(Tracker {
220 key: Some(key),
221 dropped_keys,
222 });
223
224 *o.get_mut() = Arc::downgrade(&tracker);
225 tracker
226 }))
227 }
228 }
229 }
230 }
231
232 fn poll_listener(
233 mut self: Pin<&mut Self>,
234 cx: &mut Context<'_>,
235 ) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
236 match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
237 Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
238 None => Poll::Ready(None),
239 }
240 }
241
242 fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
243 match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) {
244 Some(key) => {
245 debug!("All channels dropped for key [{}]", key);
246 self.as_mut().key_counts().remove(&key);
247 self.as_mut().key_counts().compact(0.1);
248 Poll::Ready(())
249 }
250 None => unreachable!("Holding a copy of closed_channels and didn't close it."),
251 }
252 }
253}
254
255impl<S, K, F> Stream for ChannelFilter<S, K, F>
256where
257 S: Stream,
258 K: fmt::Display + Eq + Hash + Clone + Unpin,
259 F: Fn(&S::Item) -> K,
260{
261 type Item = TrackedChannel<S::Item, K>;
262
263 fn poll_next(
264 mut self: Pin<&mut Self>,
265 cx: &mut Context<'_>,
266 ) -> Poll<Option<TrackedChannel<S::Item, K>>> {
267 loop {
268 match (
269 self.as_mut().poll_listener(cx),
270 self.as_mut().poll_closed_channels(cx),
271 ) {
272 (Poll::Ready(Some(Ok(channel))), _) => {
273 return Poll::Ready(Some(channel));
274 }
275 (Poll::Ready(Some(Err(_))), _) => {
276 continue;
277 }
278 (_, Poll::Ready(())) => continue,
279 (Poll::Pending, Poll::Pending) => return Poll::Pending,
280 (Poll::Ready(None), Poll::Pending) => {
281 trace!("Shutting down listener.");
282 return Poll::Ready(None);
283 }
284 }
285 }
286 }
287}
288
#[cfg(test)]
/// Builds a task context whose waker does nothing, for polling futures by hand in tests.
fn ctx() -> Context<'static> {
    Context::from_waker(futures_test::task::noop_waker_ref())
}
295
#[test]
fn tracker_drop() {
    use assert_matches::assert_matches;

    let (key_tx, mut key_rx) = mpsc::unbounded();
    // Dropping a tracker must report its key on the channel.
    drop(Tracker {
        key: Some(1),
        dropped_keys: key_tx,
    });
    assert_matches!(key_rx.try_next(), Ok(Some(1)));
}
307
#[test]
fn tracked_channel_stream() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    let (item_tx, item_rx) = mpsc::unbounded();
    let (key_tx, _) = mpsc::unbounded();
    let tracker = Arc::new(Tracker {
        key: Some(1),
        dropped_keys: key_tx,
    });
    let channel = TrackedChannel {
        inner: item_rx,
        tracker,
    };

    // An item sent into the wrapped receiver comes out of the tracked stream.
    item_tx.unbounded_send("test").unwrap();
    pin_mut!(channel);
    let polled = channel.poll_next(&mut ctx());
    assert_matches!(polled, Poll::Ready(Some("test")));
}
327
#[test]
fn tracked_channel_sink() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    let (item_tx, mut item_rx) = mpsc::unbounded();
    let (key_tx, _) = mpsc::unbounded();
    let tracker = Arc::new(Tracker {
        key: Some(1),
        dropped_keys: key_tx,
    });
    let channel = TrackedChannel {
        inner: item_tx,
        tracker,
    };
    pin_mut!(channel);

    // Push one item through the tracked sink and check it reaches the receiver.
    let mut cx = ctx();
    assert_matches!(channel.as_mut().poll_ready(&mut cx), Poll::Ready(Ok(())));
    assert_matches!(channel.as_mut().start_send("test"), Ok(()));
    assert_matches!(channel.as_mut().poll_flush(&mut cx), Poll::Ready(Ok(())));
    assert_matches!(item_rx.try_next(), Ok(Some("test")));
}
349
#[test]
fn channel_filter_increment_channels_for_key() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    struct TestChannel {
        key: &'static str,
    }
    let (_, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    let first = filter.as_mut().increment_channels_for_key("key").unwrap();
    assert_eq!(Arc::strong_count(&first), 1);

    let second = filter.as_mut().increment_channels_for_key("key").unwrap();
    assert_eq!(Arc::strong_count(&first), 2);

    // The limit of two channels for "key" is reached, so a third is refused.
    assert_matches!(filter.increment_channels_for_key("key"), Err("key"));

    drop(second);
    assert_eq!(Arc::strong_count(&first), 1);
}
369
#[test]
fn channel_filter_handle_new_channel() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (_, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    let first = filter
        .as_mut()
        .handle_new_channel(TestChannel { key: "key" })
        .unwrap();
    assert_eq!(Arc::strong_count(&first.tracker), 1);

    let second = filter
        .as_mut()
        .handle_new_channel(TestChannel { key: "key" })
        .unwrap();
    assert_eq!(Arc::strong_count(&first.tracker), 2);

    // A third channel with the same key exceeds the limit of two and is rejected.
    let rejected = filter.handle_new_channel(TestChannel { key: "key" });
    assert_matches!(rejected, Err("key"));

    drop(second);
    assert_eq!(Arc::strong_count(&first.tracker), 1);
}
401
#[test]
fn channel_filter_poll_listener() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (new_channels, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);
    // Queues a new incoming channel with the given key on the listener.
    let open = |key| new_channels.unbounded_send(TestChannel { key }).unwrap();

    open("key");
    let first =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
    assert_eq!(Arc::strong_count(&first.tracker), 1);

    open("key");
    let _second =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
    assert_eq!(Arc::strong_count(&first.tracker), 2);

    // Over the limit: the listener reports the rejected key instead of a channel.
    open("key");
    let rejected =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
    assert_eq!(rejected, "key");
    assert_eq!(Arc::strong_count(&first.tracker), 2);
}
437
#[test]
fn channel_filter_poll_closed_channels() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (new_channels, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    new_channels
        .unbounded_send(TestChannel { key: "key" })
        .unwrap();
    let polled = filter.as_mut().poll_listener(&mut ctx());
    let channel = assert_matches!(polled, Poll::Ready(Some(Ok(c))) => c);
    assert_eq!(filter.key_counts.len(), 1);

    // Closing the only channel for the key queues a notification; handling it removes the
    // key's entry.
    drop(channel);
    let polled = filter.as_mut().poll_closed_channels(&mut ctx());
    assert_matches!(polled, Poll::Ready(()));
    assert!(filter.key_counts.is_empty());
}
465
#[test]
fn channel_filter_stream() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (new_channels, listener) = mpsc::unbounded();
    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    new_channels
        .unbounded_send(TestChannel { key: "key" })
        .unwrap();
    let next = filter.as_mut().poll_next(&mut ctx());
    let channel = assert_matches!(next, Poll::Ready(Some(c)) => c);
    assert_eq!(filter.key_counts.len(), 1);

    // After the channel closes, the next poll drains the notification and then goes pending.
    drop(channel);
    let next = filter.as_mut().poll_next(&mut ctx());
    assert_matches!(next, Poll::Pending);
    assert!(filter.key_counts.is_empty());
}