split_stream_by/
split_by_map.rs

1use std::{
2    marker::PhantomData,
3    pin::Pin,
4    sync::{Arc, Mutex},
5    task::{Poll, Waker},
6};
7
8use futures::{future::Either, Stream};
9use pin_project::pin_project;
10
11#[pin_project]
12pub(crate) struct SplitByMap<I, L, R, S, P> {
13    buf_left: Option<L>,
14    buf_right: Option<R>,
15    waker_left: Option<Waker>,
16    waker_right: Option<Waker>,
17    #[pin]
18    stream: S,
19    predicate: P,
20    item: PhantomData<I>,
21}
22
23impl<I, L, R, S, P> SplitByMap<I, L, R, S, P>
24where
25    S: Stream<Item = I>,
26    P: Fn(I) -> Either<L, R>,
27{
28    pub(crate) fn new(stream: S, predicate: P) -> Arc<Mutex<Self>> {
29        Arc::new(Mutex::new(Self {
30            buf_right: None,
31            buf_left: None,
32            waker_right: None,
33            waker_left: None,
34            stream,
35            predicate,
36            item: PhantomData,
37        }))
38    }
39
40    fn poll_next_left(
41        self: std::pin::Pin<&mut Self>,
42        cx: &mut std::task::Context<'_>,
43    ) -> std::task::Poll<Option<L>> {
44        let this = self.project();
45        // There should only ever be one waker calling the function
46        if this.waker_left.is_none() {
47            *this.waker_left = Some(cx.waker().clone());
48        }
49        if let Some(item) = this.buf_left.take() {
50            // There was already a value in the buffer. Return that value
51            return Poll::Ready(Some(item));
52        }
53        if this.buf_right.is_some() {
54            // There is a value available for the other stream. Wake that stream if possible
55            // and return pending since we can't store multiple values for a stream
56            if let Some(waker) = this.waker_right {
57                waker.wake_by_ref();
58            }
59            return Poll::Pending;
60        }
61        match this.stream.poll_next(cx) {
62            Poll::Ready(Some(item)) => {
63                match (this.predicate)(item) {
64                    Either::Left(left_item) => Poll::Ready(Some(left_item)),
65                    Either::Right(right_item) => {
66                        // This value is not what we wanted. Store it and notify other partition
67                        // task if it exists
68                        let _ = this.buf_right.replace(right_item);
69                        if let Some(waker) = this.waker_right {
70                            waker.wake_by_ref();
71                        }
72                        Poll::Pending
73                    }
74                }
75            }
76            Poll::Ready(None) => {
77                // If the underlying stream is finished, the `right` stream also must be
78                // finished, so wake it in case nothing else polls it
79                if let Some(waker) = this.waker_right {
80                    waker.wake_by_ref();
81                }
82                Poll::Ready(None)
83            }
84            Poll::Pending => Poll::Pending,
85        }
86    }
87
88    fn poll_next_right(
89        self: std::pin::Pin<&mut Self>,
90        cx: &mut std::task::Context<'_>,
91    ) -> std::task::Poll<Option<R>> {
92        let this = self.project();
93        // I think there should only ever be one waker calling the function
94        if this.waker_right.is_none() {
95            *this.waker_right = Some(cx.waker().clone());
96        }
97        if let Some(item) = this.buf_right.take() {
98            // There was already a value in the buffer. Return that value
99            return Poll::Ready(Some(item));
100        }
101        if this.buf_left.is_some() {
102            // There is a value available for the other stream. Wake that stream if possible
103            // and return pending since we can't store multiple values for a stream
104            if let Some(waker) = this.waker_left {
105                waker.wake_by_ref();
106            }
107            return Poll::Pending;
108        }
109        match this.stream.poll_next(cx) {
110            Poll::Ready(Some(item)) => {
111                match (this.predicate)(item) {
112                    Either::Left(left_item) => {
113                        // This value is not what we wanted. Store it and notify other partition
114                        // task if it exists
115                        let _ = this.buf_left.replace(left_item);
116                        if let Some(waker) = this.waker_left {
117                            waker.wake_by_ref();
118                        }
119                        Poll::Pending
120                    }
121                    Either::Right(right_item) => Poll::Ready(Some(right_item)),
122                }
123            }
124            Poll::Ready(None) => {
125                // If the underlying stream is finished, the `left` stream also must be
126                // finished, so wake it in case nothing else polls it
127                if let Some(waker) = this.waker_left {
128                    waker.wake_by_ref();
129                }
130                Poll::Ready(None)
131            }
132            Poll::Pending => Poll::Pending,
133        }
134    }
135}
136
137/// A struct that implements `Stream` which returns the inner values where
138/// the predicate returns `Either::Left(..)` when using `split_by_map`
139pub struct LeftSplitByMap<I, L, R, S, P> {
140    stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>,
141}
142
143impl<I, L, R, S, P> LeftSplitByMap<I, L, R, S, P> {
144    pub(crate) fn new(stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>) -> Self {
145        Self { stream }
146    }
147}
148
149impl<I, L, R, S, P> Stream for LeftSplitByMap<I, L, R, S, P>
150where
151    S: Stream<Item = I> + Unpin,
152    P: Fn(I) -> Either<L, R>,
153{
154    type Item = L;
155    fn poll_next(
156        self: std::pin::Pin<&mut Self>,
157        cx: &mut std::task::Context<'_>,
158    ) -> std::task::Poll<Option<Self::Item>> {
159        let response = if let Ok(mut guard) = self.stream.try_lock() {
160            SplitByMap::poll_next_left(Pin::new(&mut guard), cx)
161        } else {
162            cx.waker().wake_by_ref();
163            Poll::Pending
164        };
165        response
166    }
167}
168
169/// A struct that implements `Stream` which returns the inner values where
170/// the predicate returns `Either::Right(..)` when using `split_by_map`
171pub struct RightSplitByMap<I, L, R, S, P> {
172    stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>,
173}
174
175impl<I, L, R, S, P> RightSplitByMap<I, L, R, S, P> {
176    pub(crate) fn new(stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>) -> Self {
177        Self { stream }
178    }
179}
180
181impl<I, L, R, S, P> Stream for RightSplitByMap<I, L, R, S, P>
182where
183    S: Stream<Item = I> + Unpin,
184    P: Fn(I) -> Either<L, R>,
185{
186    type Item = R;
187    fn poll_next(
188        self: std::pin::Pin<&mut Self>,
189        cx: &mut std::task::Context<'_>,
190    ) -> std::task::Poll<Option<Self::Item>> {
191        let response = if let Ok(mut guard) = self.stream.try_lock() {
192            SplitByMap::poll_next_right(Pin::new(&mut guard), cx)
193        } else {
194            cx.waker().wake_by_ref();
195            Poll::Pending
196        };
197        response
198    }
199}