tor_async_utils/
stream_peek.rs

1//! [`StreamUnobtrusivePeeker`]
2//!
3//! The memory tracker needs a way to look at the next item of a stream
4//! (if there is one, or there can immediately be one),
5//! *without* getting involved with the async tasks.
6
7use educe::Educe;
8use futures::Stream;
9use futures::stream::FusedStream;
10use pin_project::pin_project;
11
12use crate::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
13
14use std::fmt::Debug;
15use std::future::Future;
16use std::pin::Pin;
17use std::task::{Context, Poll, Poll::*, Waker};
18
19/// Wraps [`Stream`] and provides `\[poll_]peek` and `unobtrusive_peek`
20///
21/// [`unobtrusive_peek`](StreamUnobtrusivePeeker::unobtrusive_peek)
22/// is callable in sync contexts, outside the reading task.
23///
24/// Like [`futures::stream::Peekable`],
25/// this has an async `peek` method, and `poll_peek`,
26/// for use from the task that is also reading (via the [`Stream`] impl).
27/// But, that type doesn't have `unobtrusive_peek`.
28///
29/// One way to conceptualise this is that `StreamUnobtrusivePeeker` is dual-ported:
30/// the two sets of APIs, while provided on the same type,
31/// are typically called from different contexts.
32//
33// It wasn't particularly easy to think of a good name for this type.
34// We intend, probably:
35//     struct StreamUnobtrusivePeeker
36//     trait StreamUnobtrusivePeekable
37//     trait StreamPeekable (impl for StreamUnobtrusivePeeker and futures::stream::Peekable)
38//
39// Searching a thesaurus produced these suggested words:
40//     unobtrusive subtle discreet inconspicuous cautious furtive
41// Asking in MR review also suggested
42//     quick
43//
44// It's awkward because "peek" already has significant connotations of not disturbing things.
45// That's why it was used in Iterator::peek.
46//
47// But when we translate this into async context,
48// we have the poll_peek method on futures::stream::Peekable,
49// which doesn't remove items from the stream,
50// but *does* *wait* for items and therefore engages with the async context,
51// and therefore involves *mutating* the Peekable (to store the new waker).
52//
53// Now we end up needing a word for an *even less disturbing* kind of interaction.
54//
55// `quick` (and synonyms) isn't quite right either because it's not necessarily faster,
56// and certainly not more performant.
#[derive(Debug)]
#[pin_project(project = PeekerProj)]
pub struct StreamUnobtrusivePeeker<S: Stream> {
    /// An item that we have peeked.
    ///
    /// `Some` when an item has been obtained from `inner`
    /// but not yet handed out by `poll_next`.
    ///
    /// (If we peeked EOF, that's represented by `None` in `inner`.)
    buffered: Option<S::Item>,

    /// The `Waker` from the last time we were polled and returned `Pending`
    ///
    /// "polled" includes any of our `poll_` methods
    /// but *not* `unobtrusive_peek`.
    ///
    /// `None` if we haven't been polled, or the last poll returned `Ready`.
    ///
    /// `unobtrusive_peek_mut` hands this waker to `inner` when it needs to poll,
    /// and wakes it if that poll returns `Ready`.
    poll_waker: Option<Waker>,

    /// The inner stream
    ///
    /// `None` if it has yielded `None` meaning EOF.  We don't require S: FusedStream.
    #[pin]
    inner: Option<S>,
}
79
80impl<S: Stream> StreamUnobtrusivePeeker<S> {
81    /// Create a new `StreamUnobtrusivePeeker` from a `Stream`
82    pub fn new(inner: S) -> Self {
83        StreamUnobtrusivePeeker {
84            buffered: None,
85            poll_waker: None,
86            inner: Some(inner),
87        }
88    }
89}
90
91impl<S: Stream> UnobtrusivePeekableStream for StreamUnobtrusivePeeker<S> {
92    fn unobtrusive_peek_mut<'s>(mut self: Pin<&'s mut Self>) -> Option<&'s mut S::Item> {
93        #[allow(clippy::question_mark)] // We use explicit control flow here for clarity
94        if self.as_mut().project().buffered.is_none() {
95            // We don't have a buffered item, but the stream may have an item available.
96            // We must poll it to find out.
97            //
98            // We need to pass a Context to poll_next.
99            // inner may store this context, replacing one provided via poll_*.
100            //
101            // Despite that, we need to make sure that wakeups will happen as expected.
102            // To achieve this we have retained a copy of the caller's Waker.
103            //
104            // When a future or stream returns Pending, it proposes to wake `waker`
105            // when it wants to be polled again.
106            //
107            // We uphold that promise by
108            // - only returning Pending from our poll methods if inner also returned Pending
109            // - when one of our poll methods returns Pending, saving the caller-supplied
110            //   waker, so that we can make the intermediate poll call here.
111            //
112            // If the inner poll returns Ready, inner no longer guarantees to wake anyone.
113            // In principle, if our user is waiting (we returned Pending),
114            // then inner ought to have called `wake` on the caller's `Waker`.
115            // But I don't think we can guarantee that an executor won't defer a wakeup,
116            // and respond to a dropped Waker by cancelling that wakeup;
117            // or to put it another way, the wakeup might be "in flight" on entry,
118            // but the call to inner's poll_next returning Ready
119            // might somehow "cancel" the wakeup.
120            //
121            // So just to be sure, if we get a Ready here, we wake the stored waker.
122
123            let mut self_ = self.as_mut().project();
124
125            let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
126                return None;
127            };
128
129            let waker = if let Some(waker) = self_.poll_waker.as_ref() {
130                waker
131            } else {
132                Waker::noop()
133            };
134
135            match inner.poll_next(&mut Context::from_waker(waker)) {
136                Pending => {}
137                Ready(item_or_eof) => {
138                    if let Some(waker) = self_.poll_waker.take() {
139                        waker.wake();
140                    }
141                    match item_or_eof {
142                        None => self_.inner.set(None),
143                        Some(item) => *self_.buffered = Some(item),
144                    }
145                }
146            };
147        }
148
149        self.project().buffered.as_mut()
150    }
151}
152
153impl<S: Stream> PeekableStream for StreamUnobtrusivePeeker<S> {
154    fn poll_peek<'s>(self: Pin<&'s mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
155        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_ref())
156    }
157
158    fn poll_peek_mut<'s>(
159        self: Pin<&'s mut Self>,
160        cx: &mut Context<'_>,
161    ) -> Poll<Option<&'s mut S::Item>> {
162        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_mut())
163    }
164}
165
166impl<S: Stream> StreamUnobtrusivePeeker<S> {
167    /// Implementation of `poll_{peek,next}`
168    ///
169    /// This takes care of
170    ///   * examining the state of our buffer, and polling inner if needed
171    ///   * ensuring that we store a waker, if needed
172    ///   * dealing with some borrowck awkwardness
173    ///
174    /// The `Ready` value is always calculated from `buffer`.
175    /// `return_value_obtainer` is called only if we are going to return `Ready`.
176    /// It's given `buffer` and should either:
177    ///   * [`take`](Option::take) the contained value (for `poll_next`)
178    ///   * return a reference using [`Option::as_ref`] (for `poll_peek`)
179    fn impl_poll_next_or_peek<'s, R: 's>(
180        self: Pin<&'s mut Self>,
181        cx: &mut Context<'_>,
182        return_value_obtainer: impl FnOnce(&'s mut Option<S::Item>) -> Option<R>,
183    ) -> Poll<Option<R>> {
184        let mut self_ = self.project();
185        let r = Self::next_or_peek_inner(&mut self_, cx);
186        let r = r.map(|()| return_value_obtainer(self_.buffered));
187        Self::return_from_poll(self_.poll_waker, cx, r)
188    }
189
190    /// Try to populate `buffer`, and calculate if we're `Ready`
191    ///
192    /// Returns `Ready` iff `poll_next` or `poll_peek` should return `Ready`.
193    /// The actual `Ready` value (an `Option`) will be calculated later.
194    fn next_or_peek_inner(self_: &mut PeekerProj<S>, cx: &mut Context<'_>) -> Poll<()> {
195        if let Some(_item) = self_.buffered.as_ref() {
196            // `return_value_obtainer` will find `Some` in `buffered`;
197            // overall, we'll return `Ready(Some(..))`.
198            return Ready(());
199        }
200        let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
201            // `return_value_obtainer` will find `None` in `buffered`;
202            // overall, we'll return `Ready(None)`, ie EOF.
203            return Ready(());
204        };
205        match inner.poll_next(cx) {
206            Ready(None) => {
207                self_.inner.set(None);
208                // `buffered` is `None`, still.
209                // overall, we'll return `Ready(None)`, ie EOF.
210                Ready(())
211            }
212            Ready(Some(item)) => {
213                *self_.buffered = Some(item);
214                // return_value_obtainer` will find `Some` in `buffered`
215                Ready(())
216            }
217            Pending => {
218                // `return_value_obtainer` won't be called.
219                // overall, we'll return Pending
220                Pending
221            }
222        }
223    }
224
225    /// Wait for an item to be ready, and then inspect it
226    ///
227    /// Equivalent to [`futures::stream::Peekable::peek`].
228    ///
229    /// # Tasks, waking, and calling context
230    ///
231    /// This should be called by the task that is reading from the stream.
232    /// If it is called by another task, the reading task would miss notifications.
233    //
234    // This ^ docs section is triplicated for poll_peek, poll_peek_mut, and peek
235    //
236    // TODO this should be a method on the `PeekableStream` trait? Or a
237    // `PeekableStreamExt` trait?
238    // TODO should there be peek_mut ?
239    #[allow(dead_code)] // TODO remove this allow if and when we make this module public
240    pub fn peek(self: Pin<&mut Self>) -> PeekFuture<Self> {
241        PeekFuture { peeker: Some(self) }
242    }
243
244    /// Return from a `poll_*` function, setting the stored waker appropriately
245    ///
246    /// Our `poll` functions always use this.
247    /// The rule is that if a future returns `Pending`, it has stored the waker.
248    fn return_from_poll<R>(
249        poll_waker: &mut Option<Waker>,
250        cx: &mut Context<'_>,
251        r: Poll<R>,
252    ) -> Poll<R> {
253        *poll_waker = match &r {
254            Ready(_) => {
255                // No need to wake this task up any more.
256                None
257            }
258            Pending => {
259                // try_peek must use the same waker to poll later
260                Some(cx.waker().clone())
261            }
262        };
263        r
264    }
265
266    /// Obtain a raw reference to the inner stream
267    ///
268    /// ### Correctness!
269    ///
270    /// This method must be used with care!
271    /// Whatever you do mustn't interfere with polling and peeking.
272    /// Careless use can result in wrong behaviour including deadlocks.
273    pub fn as_raw_inner_pin_mut<'s>(self: Pin<&'s mut Self>) -> Option<Pin<&'s mut S>> {
274        self.project().inner.as_pin_mut()
275    }
276}
277
278impl<S: Stream> Stream for StreamUnobtrusivePeeker<S> {
279    type Item = S::Item;
280
281    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
282        self.impl_poll_next_or_peek(cx, |buffered| buffered.take())
283    }
284
285    fn size_hint(&self) -> (usize, Option<usize>) {
286        let buf = self.buffered.iter().count();
287        let (imin, imax) = match &self.inner {
288            Some(inner) => inner.size_hint(),
289            None => (0, Some(0)),
290        };
291        (imin + buf, imax.and_then(|imap| imap.checked_add(buf)))
292    }
293}
294
295impl<S: Stream> FusedStream for StreamUnobtrusivePeeker<S> {
296    fn is_terminated(&self) -> bool {
297        self.buffered.is_none() && self.inner.is_none()
298    }
299}
300
301/// Future from [`StreamUnobtrusivePeeker::peek`]
302// TODO: Move to tor_async_utils::peekable_stream.
303#[derive(Educe)]
304#[educe(Debug(bound("S: Debug")))]
305#[must_use = "peek() return a Future, which does nothing unless awaited"]
306pub struct PeekFuture<'s, S> {
307    /// The underlying stream.
308    ///
309    /// `Some` until we have returned `Ready`, then `None`.
310    /// See comment in `poll`.
311    peeker: Option<Pin<&'s mut S>>,
312}
313
314impl<'s, S: PeekableStream> PeekFuture<'s, S> {
315    /// Create a new `PeekFuture`.
316    // TODO: replace with a trait method.
317    pub fn new(stream: Pin<&'s mut S>) -> Self {
318        Self {
319            peeker: Some(stream),
320        }
321    }
322}
323
324impl<'s, S: PeekableStream> Future for PeekFuture<'s, S> {
325    type Output = Option<&'s S::Item>;
326    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
327        let self_ = self.get_mut();
328        let peeker = self_
329            .peeker
330            .as_mut()
331            .expect("PeekFuture polled after Ready");
332        match peeker.as_mut().poll_peek(cx) {
333            Pending => return Pending,
334            Ready(_y) => {
335                // Ideally we would have returned `y` here, but it's borrowed from PeekFuture
336                // not from the original StreamUnobtrusivePeeker, and there's no way
337                // to get a value with the right lifetime.  (In non-async code,
338                // this is usually handled by the special magic for reborrowing &mut.)
339                //
340                // So we must redo the poll, but this time consuming `peeker`,
341                // which gets us the right lifetime.  That's why it has to be `Option`.
342                // Because we own &mut ... Self, we know that repeating the poll
343                // gives the same answer.
344            }
345        }
346        let peeker = self_.peeker.take().expect("it was Some before!");
347        let r = peeker.poll_peek(cx);
348        assert!(r.is_ready(), "it was Ready before!");
349        r
350    }
351}
352
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use futures::channel::mpsc;
    use futures::{SinkExt as _, StreamExt as _};
    use std::pin::pin;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tor_rtcompat::SleepProvider as _;
    use tor_rtmock::MockRuntime;

    /// Shorthand for a `Duration` of `ms` milliseconds
    fn ms(ms: u64) -> Duration {
        Duration::from_millis(ms)
    }

    /// Interleave `unobtrusive_peek_mut` with `next()`, against a sender on its own schedule
    ///
    /// The "rxr" task alternates between peeking unobtrusively and consuming an item,
    /// sleeping 50ms before each step, and checks that the values arrive in sequence.
    /// The "tx" task sends increasing integers after the listed delays, then drops
    /// the sender (EOF) and checks when the receiver notices.
    #[test]
    fn wakeups() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set to `true` by the receiving task once it has seen EOF.
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();

                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);

                    // The value we expect to see next, whether peeked or eaten.
                    let mut next = 0;
                    loop {
                        rt.sleep(ms(50)).await;
                        eprintln!("rx peek... ");
                        let peeked = rx.as_mut().unobtrusive_peek_mut();
                        eprintln!("rx peeked {peeked:?}");

                        // A peek may find nothing; if it finds something, it must be `next`.
                        if let Some(peeked) = peeked {
                            assert_eq!(*peeked, next);
                        }

                        rt.sleep(ms(50)).await;
                        eprintln!("rx next... ");
                        let eaten = rx.next().await;
                        eprintln!("rx eaten {eaten:?}");
                        if let Some(eaten) = eaten {
                            assert_eq!(eaten, next);
                            next += 1;
                        } else {
                            // EOF
                            break;
                        }
                    }

                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let mut numbers = 0..;
                    // Delay (in ms) before each successive send.
                    for wait in [125, 1, 125, 45, 1, 1, 1, 1000, 20, 1, 125, 125, 1000] {
                        eprintln!("tx sleep {wait}");
                        rt.sleep(ms(wait)).await;
                        let num = numbers.next().unwrap();
                        eprintln!("tx sending {num}");
                        tx.send(num).await.unwrap();
                    }

                    // This schedule arranges that, when we send EOF, the rx task
                    // has *peeked* rather than *polled* most recently,
                    // demonstrating that we can wake up the subsequent poll on EOF too.
                    eprintln!("tx final #1");
                    rt.sleep(ms(75)).await;
                    eprintln!("tx EOF");
                    drop(tx);
                    eprintln!("tx final #2");
                    rt.sleep(ms(10)).await;
                    // 10ms after EOF, the receiver must not have finished yet...
                    assert!(!*ended.lock().unwrap());
                    eprintln!("tx final #3");
                    rt.sleep(ms(50)).await;
                    eprintln!("tx final #4");
                    // ... but a further 50ms later it must have.
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }

    /// Exercise `peek().await` followed by `next().await` in the reading task
    ///
    /// The "rxr" task peeks, then consumes, and checks that both saw the same value;
    /// it sleeps 10ms between rounds.
    /// The "tx" task sends one item, then two items back to back, then drops the sender
    /// and checks that the receiver has ended 5ms later.
    #[test]
    fn poll_peek_paths() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set to `true` by the receiving task once `peek` has reported EOF.
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();

                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);

                    // `peek` resolving to `None` means EOF, which ends the loop.
                    while let Some(peeked) = rx.as_mut().peek().await.copied() {
                        eprintln!("rx peeked {peeked}");
                        let eaten = rx.next().await.unwrap();
                        eprintln!("rx eaten  {eaten}");
                        assert_eq!(peeked, eaten);
                        rt.sleep(ms(10)).await;
                        eprintln!("rx slept, peeking");
                    }
                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let mut numbers = 0..;

                    // macro because we don't have proper async closures
                    macro_rules! send { {} => {
                        let num = numbers.next().unwrap();
                        eprintln!("tx send   {num}");
                        tx.send(num).await.unwrap();
                    } }

                    eprintln!("tx starting");
                    rt.sleep(ms(100)).await;
                    send!();
                    rt.sleep(ms(100)).await;
                    // Two sends with no sleep in between.
                    send!();
                    send!();
                    rt.sleep(ms(100)).await;
                    eprintln!("tx dropping");
                    drop(tx);
                    rt.sleep(ms(5)).await;
                    eprintln!("tx ending");
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }
}