split_stream_by/
split_by.rs

1use std::{
2    pin::Pin,
3    sync::{Arc, Mutex},
4    task::{Poll, Waker},
5};
6
7use futures::Stream;
8use pin_project::pin_project;
9
10#[pin_project]
11pub(crate) struct SplitBy<I, S, P> {
12    buf_true: Option<I>,
13    buf_false: Option<I>,
14    waker_true: Option<Waker>,
15    waker_false: Option<Waker>,
16    #[pin]
17    stream: S,
18    predicate: P,
19}
20
21impl<I, S, P> SplitBy<I, S, P>
22where
23    S: Stream<Item = I>,
24    P: Fn(&I) -> bool,
25{
26    pub(crate) fn new(stream: S, predicate: P) -> Arc<Mutex<Self>> {
27        Arc::new(Mutex::new(Self {
28            buf_false: None,
29            buf_true: None,
30            waker_false: None,
31            waker_true: None,
32            stream,
33            predicate,
34        }))
35    }
36
37    fn poll_next_true(
38        self: std::pin::Pin<&mut Self>,
39        cx: &mut std::task::Context<'_>,
40    ) -> std::task::Poll<Option<I>> {
41        let this = self.project();
42        // There should only ever be one waker calling the function
43        if this.waker_true.is_none() {
44            *this.waker_true = Some(cx.waker().clone());
45        }
46        if let Some(item) = this.buf_true.take() {
47            // There was already a value in the buffer. Return that value
48            return Poll::Ready(Some(item));
49        }
50        if this.buf_false.is_some() {
51            // There is a value available for the other stream. Wake that stream if possible
52            // and return pending since we can't store multiple values for a stream
53            if let Some(waker) = this.waker_false {
54                waker.wake_by_ref();
55            }
56            return Poll::Pending;
57        }
58        match this.stream.poll_next(cx) {
59            Poll::Ready(Some(item)) => {
60                if (this.predicate)(&item) {
61                    Poll::Ready(Some(item))
62                } else {
63                    // This value is not what we wanted. Store it and notify other partition task if
64                    // it exists
65                    let _ = this.buf_false.replace(item);
66                    if let Some(waker) = this.waker_false {
67                        waker.wake_by_ref();
68                    }
69                    Poll::Pending
70                }
71            }
72            Poll::Ready(None) => {
73                // If the underlying stream is finished, the `false` stream also must be
74                // finished, so wake it in case nothing else polls it
75                if let Some(waker) = this.waker_false {
76                    waker.wake_by_ref();
77                }
78                Poll::Ready(None)
79            }
80            Poll::Pending => Poll::Pending,
81        }
82    }
83
84    fn poll_next_false(
85        self: std::pin::Pin<&mut Self>,
86        cx: &mut std::task::Context<'_>,
87    ) -> std::task::Poll<Option<I>> {
88        let this = self.project();
89        // I think there should only ever be one waker calling the function
90        if this.waker_false.is_none() {
91            *this.waker_false = Some(cx.waker().clone());
92        }
93        if let Some(item) = this.buf_false.take() {
94            // There was already a value in the buffer. Return that value
95            return Poll::Ready(Some(item));
96        }
97        if this.buf_true.is_some() {
98            // There is a value available for the other stream. Wake that stream if possible
99            // and return pending since we can't store multiple values for a stream
100            if let Some(waker) = this.waker_true {
101                waker.wake_by_ref();
102            }
103            return Poll::Pending;
104        }
105        match this.stream.poll_next(cx) {
106            Poll::Ready(Some(item)) => {
107                if (this.predicate)(&item) {
108                    // This value is not what we wanted. Store it and notify other stream if waker
109                    // exists
110                    let _ = this.buf_true.replace(item);
111                    if let Some(waker) = this.waker_true {
112                        waker.wake_by_ref();
113                    }
114                    Poll::Pending
115                } else {
116                    Poll::Ready(Some(item))
117                }
118            }
119            Poll::Ready(None) => {
120                // If the underlying stream is finished, the `true` stream also must be
121                // finished, so wake it in case nothing else polls it
122                if let Some(waker) = this.waker_true {
123                    waker.wake_by_ref();
124                }
125                Poll::Ready(None)
126            }
127            Poll::Pending => Poll::Pending,
128        }
129    }
130}
131
132/// A struct that implements `Stream` which returns the items where the
133/// predicate returns `true`
134pub struct TrueSplitBy<I, S, P> {
135    stream: Arc<Mutex<SplitBy<I, S, P>>>,
136}
137
138impl<I, S, P> TrueSplitBy<I, S, P> {
139    pub(crate) fn new(stream: Arc<Mutex<SplitBy<I, S, P>>>) -> Self {
140        Self { stream }
141    }
142}
143
144impl<I, S, P> Stream for TrueSplitBy<I, S, P>
145where
146    S: Stream<Item = I> + Unpin,
147    P: Fn(&I) -> bool,
148{
149    type Item = I;
150    fn poll_next(
151        self: std::pin::Pin<&mut Self>,
152        cx: &mut std::task::Context<'_>,
153    ) -> std::task::Poll<Option<Self::Item>> {
154        let response = if let Ok(mut guard) = self.stream.try_lock() {
155            SplitBy::poll_next_true(Pin::new(&mut guard), cx)
156        } else {
157            cx.waker().wake_by_ref();
158            Poll::Pending
159        };
160        response
161    }
162}
163
164/// A struct that implements `Stream` which returns the items where the
165/// predicate returns `false`
166pub struct FalseSplitBy<I, S, P> {
167    stream: Arc<Mutex<SplitBy<I, S, P>>>,
168}
169
170impl<I, S, P> FalseSplitBy<I, S, P> {
171    pub(crate) fn new(stream: Arc<Mutex<SplitBy<I, S, P>>>) -> Self {
172        Self { stream }
173    }
174}
175
176impl<I, S, P> Stream for FalseSplitBy<I, S, P>
177where
178    S: Stream<Item = I> + Unpin,
179    P: Fn(&I) -> bool,
180{
181    type Item = I;
182    fn poll_next(
183        self: std::pin::Pin<&mut Self>,
184        cx: &mut std::task::Context<'_>,
185    ) -> std::task::Poll<Option<Self::Item>> {
186        let response = if let Ok(mut guard) = self.stream.try_lock() {
187            SplitBy::poll_next_false(Pin::new(&mut guard), cx)
188        } else {
189            cx.waker().wake_by_ref();
190            Poll::Pending
191        };
192        response
193    }
194}