split_stream_by/
split_by_buffered.rs

1use std::{
2    pin::Pin,
3    sync::{Arc, Mutex},
4    task::{Poll, Waker},
5};
6
7use crate::ring_buf::RingBuf;
8use futures::Stream;
9use pin_project::pin_project;
10
11#[pin_project]
12pub(crate) struct SplitByBuffered<I, S, P, const N: usize> {
13    buf_true: RingBuf<I, N>,
14    buf_false: RingBuf<I, N>,
15    waker_true: Option<Waker>,
16    waker_false: Option<Waker>,
17    #[pin]
18    stream: S,
19    predicate: P,
20}
21
22impl<I, S, P, const N: usize> SplitByBuffered<I, S, P, N>
23where
24    S: Stream<Item = I>,
25    P: Fn(&I) -> bool,
26{
27    pub(crate) fn new(stream: S, predicate: P) -> Arc<Mutex<Self>> {
28        Arc::new(Mutex::new(Self {
29            buf_false: RingBuf::new(),
30            buf_true: RingBuf::new(),
31            waker_false: None,
32            waker_true: None,
33            stream,
34            predicate,
35        }))
36    }
37
38    fn poll_next_true(
39        self: std::pin::Pin<&mut Self>,
40        cx: &mut std::task::Context<'_>,
41    ) -> std::task::Poll<Option<I>> {
42        let this = self.project();
43        // There should only ever be one waker calling the function
44        if this.waker_true.is_none() {
45            *this.waker_true = Some(cx.waker().clone());
46        }
47        if let Some(item) = this.buf_true.pop_front() {
48            // There was already a value in the buffer. Return that value
49            return Poll::Ready(Some(item));
50        }
51        if this.buf_false.remaining() == 0 {
52            // The other buffer is full, so notify that stream and return pending
53            if let Some(waker) = this.waker_false {
54                waker.wake_by_ref();
55            }
56            return Poll::Pending;
57        }
58        match this.stream.poll_next(cx) {
59            Poll::Ready(Some(item)) => {
60                if (this.predicate)(&item) {
61                    Poll::Ready(Some(item))
62                } else {
63                    // This value is not what we wanted. Store it and notify other partition task if
64                    // it exists. This can't fail because we checked above that the buffer isn't
65                    // full
66                    let _ = this.buf_false.push_back(item);
67                    if let Some(waker) = this.waker_false {
68                        waker.wake_by_ref();
69                    }
70                    Poll::Pending
71                }
72            }
73            Poll::Ready(None) => {
74                // If the underlying stream is finished, the `false` stream also must be
75                // finished, so wake it in case nothing else polls it
76                if let Some(waker) = this.waker_false {
77                    waker.wake_by_ref();
78                }
79                Poll::Ready(None)
80            }
81            Poll::Pending => Poll::Pending,
82        }
83    }
84
85    fn poll_next_false(
86        self: std::pin::Pin<&mut Self>,
87        cx: &mut std::task::Context<'_>,
88    ) -> std::task::Poll<Option<I>> {
89        let this = self.project();
90        // I think there should only ever be one waker calling the function
91        if this.waker_false.is_none() {
92            *this.waker_false = Some(cx.waker().clone());
93        }
94        if let Some(item) = this.buf_false.pop_front() {
95            // There was already a value in the buffer. Return that value
96            return Poll::Ready(Some(item));
97        }
98        if this.buf_true.remaining() == 0 {
99            // The other buffer is full, so notify that stream and return pending
100            if let Some(waker) = this.waker_true {
101                waker.wake_by_ref();
102            }
103            return Poll::Pending;
104        }
105        match this.stream.poll_next(cx) {
106            Poll::Ready(Some(item)) => {
107                if (this.predicate)(&item) {
108                    // This value is not what we wanted. Store it and notify other stream if waker
109                    // it exists. This can't fail because we checked above that the buffer isn't
110                    // full
111                    let _ = this.buf_true.push_back(item);
112                    if let Some(waker) = this.waker_true {
113                        waker.wake_by_ref();
114                    }
115                    Poll::Pending
116                } else {
117                    Poll::Ready(Some(item))
118                }
119            }
120            Poll::Ready(None) => {
121                // If the underlying stream is finished, the `true` stream also must be
122                // finished, so wake it in case nothing else polls it
123                if let Some(waker) = this.waker_true {
124                    waker.wake_by_ref();
125                }
126                Poll::Ready(None)
127            }
128            Poll::Pending => Poll::Pending,
129        }
130    }
131}
132
133/// A struct that implements `Stream` which returns the items where the
134/// predicate returns `true`
135pub struct TrueSplitByBuffered<I, S, P, const N: usize> {
136    stream: Arc<Mutex<SplitByBuffered<I, S, P, N>>>,
137}
138
139impl<I, S, P, const N: usize> TrueSplitByBuffered<I, S, P, N> {
140    pub(crate) fn new(stream: Arc<Mutex<SplitByBuffered<I, S, P, N>>>) -> Self {
141        Self { stream }
142    }
143}
144
145impl<I, S, P, const N: usize> Stream for TrueSplitByBuffered<I, S, P, N>
146where
147    S: Stream<Item = I> + Unpin,
148    P: Fn(&I) -> bool,
149{
150    type Item = I;
151    fn poll_next(
152        self: std::pin::Pin<&mut Self>,
153        cx: &mut std::task::Context<'_>,
154    ) -> std::task::Poll<Option<Self::Item>> {
155        let response = if let Ok(mut guard) = self.stream.try_lock() {
156            SplitByBuffered::poll_next_true(Pin::new(&mut guard), cx)
157        } else {
158            cx.waker().wake_by_ref();
159            Poll::Pending
160        };
161        response
162    }
163}
164
165/// A struct that implements `Stream` which returns the items where the
166/// predicate returns `false`
167pub struct FalseSplitByBuffered<I, S, P, const N: usize> {
168    stream: Arc<Mutex<SplitByBuffered<I, S, P, N>>>,
169}
170
171impl<I, S, P, const N: usize> FalseSplitByBuffered<I, S, P, N> {
172    pub(crate) fn new(stream: Arc<Mutex<SplitByBuffered<I, S, P, N>>>) -> Self {
173        Self { stream }
174    }
175}
176
177impl<I, S, P, const N: usize> Stream for FalseSplitByBuffered<I, S, P, N>
178where
179    S: Stream<Item = I> + Unpin,
180    P: Fn(&I) -> bool,
181{
182    type Item = I;
183    fn poll_next(
184        self: std::pin::Pin<&mut Self>,
185        cx: &mut std::task::Context<'_>,
186    ) -> std::task::Poll<Option<Self::Item>> {
187        let response = if let Ok(mut guard) = self.stream.try_lock() {
188            SplitByBuffered::poll_next_false(Pin::new(&mut guard), cx)
189        } else {
190            cx.waker().wake_by_ref();
191            Poll::Pending
192        };
193        response
194    }
195}