// tor_async_utils/counting_streams.rs
use std::{
    pin::{Pin, pin},
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    task::ready,
    task::{Context, Poll},
};

use futures::{Stream, sink::Sink, stream::FusedStream};
use pin_project::pin_project;
19
20#[derive(Clone, Debug)]
22#[pin_project]
23pub struct CountingSink<S> {
24 #[pin]
26 inner: S,
27 count: Arc<AtomicUsize>,
31}
32
33#[derive(Clone, Debug)]
35#[pin_project]
36pub struct CountingStream<S> {
37 #[pin]
39 inner: S,
40 count: Arc<AtomicUsize>,
44}
45
46impl<T, S: Sink<T>> Sink<T> for CountingSink<S> {
47 type Error = S::Error;
48
49 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50 self.project().inner.poll_ready(cx)
51 }
52
53 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
54 let self_ = self.project();
55 let r = self_.inner.start_send(item);
56 if r.is_ok() {
57 self_.count.fetch_add(1, Ordering::Relaxed);
62 }
63 r
64 }
65
66 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67 self.project().inner.poll_flush(cx)
68 }
69
70 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71 self.project().inner.poll_close(cx)
72 }
73}
74
75impl<S: Stream> Stream for CountingStream<S> {
76 type Item = S::Item;
77
78 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
79 let self_ = self.project();
80 let next = ready!(self_.inner.poll_next(cx));
81 if next.is_some() {
82 self_.count.fetch_sub(1, Ordering::Relaxed);
86 }
87 Poll::Ready(next)
88 }
89}
90
91impl<S: FusedStream> FusedStream for CountingStream<S> {
92 fn is_terminated(&self) -> bool {
93 self.inner.is_terminated()
94 }
95}
96
97impl<S> CountingStream<S> {
98 pub fn approx_count(&self) -> usize {
104 self.count.load(Ordering::Relaxed)
105 }
106
107 pub fn inner(&self) -> &S {
112 &self.inner
113 }
114
115 pub fn inner_mut(&mut self) -> &mut S {
120 &mut self.inner
121 }
122}
123
124impl<S> CountingSink<S> {
125 pub fn approx_count(&self) -> usize {
131 self.count.load(Ordering::Relaxed)
132 }
133
134 pub fn inner(&self) -> &S {
139 &self.inner
140 }
141
142 pub fn inner_mut(&mut self) -> &mut S {
147 &mut self.inner
148 }
149}
150
151pub fn channel<T, U>(tx: T, rx: U) -> (CountingSink<T>, CountingStream<U>) {
163 let count = Arc::new(AtomicUsize::new(0));
164 let new_tx = CountingSink {
165 inner: tx,
166 count: Arc::clone(&count),
167 };
168 let new_rx = CountingStream { inner: rx, count };
169 (new_tx, new_rx)
170}
171
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use futures::{SinkExt as _, StreamExt as _};

    /// With a single sending task and no receiver activity, both ends report
    /// exactly the number of items sent so far.
    #[test]
    fn send_only_onetask() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, rx) = super::channel(tx, rx);
            for n in 1..10 {
                tx.send(n).await.unwrap();
                assert_eq!(tx.approx_count(), n);
                assert_eq!(rx.approx_count(), n);
            }
        });
    }

    /// Two cloned senders share one counter: each sees at least its own sends,
    /// and the receiver sees the combined total.
    #[test]
    fn send_only_twotasks() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, rx) = super::channel(tx, rx);

            let mut tx2 = tx.clone();
            let j1 = rt.spawn_join("thread1", async move {
                for n in 1..=10 {
                    tx.send(n).await.unwrap();
                    assert!(tx.approx_count() >= n);
                }
            });

            let j2 = rt.spawn_join("thread2", async move {
                for n in 1..=10 {
                    tx2.send(n).await.unwrap();
                    assert!(tx2.approx_count() >= n);
                }
            });
            j1.await;
            j2.await;
            assert_eq!(rx.approx_count(), 20);
        });
    }

    /// Concurrent sending and receiving: every item arrives, the count never
    /// looks wrapped around, and it returns to zero once everything is drained.
    #[test]
    fn send_and_receive() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, mut rx) = super::channel(tx, rx);
            const MAX: usize = 10000;

            let mut tx2 = tx.clone();
            let j1 = rt.spawn_join("thread1", async move {
                for n in 1..=MAX {
                    tx.send(n).await.unwrap();
                }
            });

            let j2 = rt.spawn_join("thread2", async move {
                for n in 1..=MAX {
                    tx2.send(n).await.unwrap();
                }
            });

            let j3 = rt.spawn_join("receiver", async move {
                let mut total = 0;
                while let Some(x) = rx.next().await {
                    total += x;
                    let count = rx.approx_count();
                    // An underflowed (wrapped-around) counter would show up as
                    // a value far larger than the number of items ever sent.
                    assert!(count <= MAX * 2);
                }
                // Two senders each sent 1..=MAX, i.e. 2 * MAX * (MAX + 1) / 2.
                assert_eq!(total, MAX * (MAX + 1));
                rx
            });

            j1.await;
            j2.await;
            let rx = j3.await;
            assert_eq!(rx.approx_count(), 0);
        });
    }

    /// An item rejected by the inner sink must not be counted.
    #[test]
    fn rejected_send_not_counted() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, rx) = super::channel(tx, rx);
            tx.send(1).await.unwrap();
            assert_eq!(tx.approx_count(), 1);

            // Dropping the receiving end closes the channel, so the inner
            // sender rejects further items in `start_send`.
            drop(rx);
            let r = futures::sink::Sink::start_send(std::pin::Pin::new(&mut tx), 2);
            assert!(r.is_err());
            assert_eq!(tx.approx_count(), 1);
        });
    }
}