tor_async_utils/
counting_streams.rs

1//! A facility for an MPSC channel that counts the number of outstanding entries on the channel.
2//
// (Tokio makes this possible by default, but we don't require tokio.  Crossbeam channels also allow
// this, but they aren't async, and they're MPMC. If a future version of the
// `futures` crate adds this functionality, we can use that instead.)
6
7use std::{
8    pin::{Pin, pin},
9    sync::{
10        Arc,
11        atomic::{AtomicUsize, Ordering},
12    },
13    task::ready,
14    task::{Context, Poll},
15};
16
17use futures::{Stream, sink::Sink, stream::FusedStream};
18use pin_project::pin_project;
19
20/// A wrapper around an arbitrary [`Sink`], to count the items inserted.
21#[derive(Clone, Debug)]
22#[pin_project]
23pub struct CountingSink<S> {
24    /// The inner sink whose items we're counting.
25    #[pin]
26    inner: S,
27    /// A shared counter for items inserted into the channel
28    ///
29    /// We add 1 every time we enqueue an item.
30    count: Arc<AtomicUsize>,
31}
32
33/// A wrapper around an arbitrary [`Stream`], to count the items inserted.
34#[derive(Clone, Debug)]
35#[pin_project]
36pub struct CountingStream<S> {
37    /// The inner stream whose items we're counting.
38    #[pin]
39    inner: S,
40    /// A shared counter for items inserted into the channel.
41    ///
42    /// We remove 1 every time we dequeue an item.
43    count: Arc<AtomicUsize>,
44}
45
46impl<T, S: Sink<T>> Sink<T> for CountingSink<S> {
47    type Error = S::Error;
48
49    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50        self.project().inner.poll_ready(cx)
51    }
52
53    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
54        let self_ = self.project();
55        let r = self_.inner.start_send(item);
56        if r.is_ok() {
57            // We successfully sent an item, so we increment the counter.
58            //
59            // Using `Relaxed` ensures that the operation is atomic, but does not guarantee its
60            // order with respect to operations on other locations.
61            self_.count.fetch_add(1, Ordering::Relaxed);
62        }
63        r
64    }
65
66    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67        self.project().inner.poll_flush(cx)
68    }
69
70    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71        self.project().inner.poll_close(cx)
72    }
73}
74
75impl<S: Stream> Stream for CountingStream<S> {
76    type Item = S::Item;
77
78    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
79        let self_ = self.project();
80        let next = ready!(self_.inner.poll_next(cx));
81        if next.is_some() {
82            // We got an item, so we'll decrement the counter.
83            //
84            // See note above about "Relaxed" ordering.
85            self_.count.fetch_sub(1, Ordering::Relaxed);
86        }
87        Poll::Ready(next)
88    }
89}
90
91impl<S: FusedStream> FusedStream for CountingStream<S> {
92    fn is_terminated(&self) -> bool {
93        self.inner.is_terminated()
94    }
95}
96
97impl<S> CountingStream<S> {
98    /// Return an approximate count of the number of items currently on this channel.
99    ///
100    /// This count is necessarily approximate because the count can be changed by any of this
101    /// channel's Senders or Receivers between when the caller
102    /// gets the count and when the caller uses the count.
103    pub fn approx_count(&self) -> usize {
104        self.count.load(Ordering::Relaxed)
105    }
106
107    /// Return a reference to the inner stream.
108    ///
109    /// If the stream has interior mutability, the caller must take care
110    /// not to do anything with the stream that would invalidate the current counter.
111    pub fn inner(&self) -> &S {
112        &self.inner
113    }
114
115    /// Return a mutable reference to the inner stream.
116    ///
117    /// If the stream has interior mutability, the caller must take care
118    /// not to do anything with the stream that would invalidate the current counter.
119    pub fn inner_mut(&mut self) -> &mut S {
120        &mut self.inner
121    }
122}
123
124impl<S> CountingSink<S> {
125    /// Return an approximate count of the number of items currently on this channel.
126    ///
127    /// This count is necessarily approximate because the count can be changed by any of this
128    /// channel's Senders or Receivers between when the caller
129    /// gets the count and when the caller uses the count.
130    pub fn approx_count(&self) -> usize {
131        self.count.load(Ordering::Relaxed)
132    }
133
134    /// Return a reference to the inner sink.
135    ///
136    /// If the sink has interior mutability, the caller must take care
137    /// not to do anything with the sink that would invalidate the current counter.
138    pub fn inner(&self) -> &S {
139        &self.inner
140    }
141
142    /// Return a mutable reference to the inner sink.
143    ///
144    /// If the sink has interior mutability, the caller must take care
145    /// not to do anything with the sink that would invalidate the current counter.
146    pub fn inner_mut(&mut self) -> &mut S {
147        &mut self.inner
148    }
149}
150
151/// Wrap a [`Sink`]/[`Stream`] pair into a [`CountingSink`] and [`CountingStream`] pair.
152///
153/// # Correctness
154///
155/// The sink and the stream should match and form a channel:
156/// items sent on the sink should be received from the stream.
157///
158/// There should be no other handles in use for adding or removing items from the channel.
159///
160/// If these requirements aren't met, then the counts returned by the sink and stream
161/// will not be accurate.
162pub fn channel<T, U>(tx: T, rx: U) -> (CountingSink<T>, CountingStream<U>) {
163    let count = Arc::new(AtomicUsize::new(0));
164    let new_tx = CountingSink {
165        inner: tx,
166        count: Arc::clone(&count),
167    };
168    let new_rx = CountingStream { inner: rx, count };
169    (new_tx, new_rx)
170}
171
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use futures::{SinkExt as _, StreamExt as _};

    #[test]
    fn send_only_onetask() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, rx) = super::channel(tx, rx);
            for n in 1..=10 {
                tx.send(n).await.unwrap();
                assert_eq!(tx.approx_count(), n);
                assert_eq!(rx.approx_count(), n);
            }
        });
    }

    #[test]
    fn send_only_twotasks() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, rx) = super::channel(tx, rx);

            let mut tx2 = tx.clone();
            let j1 = rt.spawn_join("thread1", async move {
                for n in 1..=10 {
                    tx.send(n).await.unwrap();
                    assert!(tx.approx_count() >= n);
                }
            });

            let j2 = rt.spawn_join("thread2", async move {
                for n in 1..=10 {
                    tx2.send(n).await.unwrap();
                    assert!(tx2.approx_count() >= n);
                }
            });
            j1.await;
            j2.await;
            assert_eq!(rx.approx_count(), 20);
        });
    }

    #[test]
    fn send_and_receive() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, mut rx) = super::channel(tx, rx);
            const MAX: usize = 10000;

            let mut tx2 = tx.clone();
            let j1 = rt.spawn_join("thread1", async move {
                for n in 1..=MAX {
                    tx.send(n).await.unwrap();
                }
            });

            let j2 = rt.spawn_join("thread2", async move {
                for n in 1..=MAX {
                    tx2.send(n).await.unwrap();
                }
            });

            let j3 = rt.spawn_join("receiver", async move {
                let mut total = 0;
                while let Some(x) = rx.next().await {
                    total += x; // spot check
                    let count = rx.approx_count();
                    assert!(count <= MAX * 2);
                }
                assert_eq!(total, MAX * (MAX + 1)); // two senders, so no "/2".
                rx
            });

            j1.await;
            j2.await;
            let rx = j3.await;
            assert_eq!(rx.approx_count(), 0);
        });
    }

    #[test]
    fn failed_send_not_counted() {
        let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
        let (mut tx, rx) = super::channel(tx, rx);

        tx.start_send_unpin(1).unwrap();
        assert_eq!(tx.approx_count(), 1);

        // Dropping the receiver closes the channel, so the inner sink rejects further
        // items; a rejected item must leave the count unchanged.
        drop(rx);
        assert!(tx.start_send_unpin(2).is_err());
        assert_eq!(tx.approx_count(), 1);
    }
}