split_stream_by/
split_by_map_buffered.rs

1use std::{
2    marker::PhantomData,
3    pin::Pin,
4    sync::{Arc, Mutex},
5    task::{Poll, Waker},
6};
7
8use futures::{future::Either, Stream};
9use pin_project::pin_project;
10
11use crate::ring_buf::RingBuf;
12
13#[pin_project]
14pub(crate) struct SplitByMapBuffered<I, L, R, S, P, const N: usize> {
15    buf_left: RingBuf<L, N>,
16    buf_right: RingBuf<R, N>,
17    waker_left: Option<Waker>,
18    waker_right: Option<Waker>,
19    #[pin]
20    stream: S,
21    predicate: P,
22    item: PhantomData<I>,
23}
24
25impl<I, L, R, S, P, const N: usize> SplitByMapBuffered<I, L, R, S, P, N>
26where
27    S: Stream<Item = I>,
28    P: Fn(I) -> Either<L, R>,
29{
30    pub(crate) fn new(stream: S, predicate: P) -> Arc<Mutex<Self>> {
31        Arc::new(Mutex::new(Self {
32            buf_right: RingBuf::new(),
33            buf_left: RingBuf::new(),
34            waker_right: None,
35            waker_left: None,
36            stream,
37            predicate,
38            item: PhantomData,
39        }))
40    }
41
42    fn poll_next_left(
43        self: std::pin::Pin<&mut Self>,
44        cx: &mut std::task::Context<'_>,
45    ) -> std::task::Poll<Option<L>> {
46        let this = self.project();
47        // There should only ever be one waker calling the function
48        if this.waker_left.is_none() {
49            *this.waker_left = Some(cx.waker().clone());
50        }
51        if let Some(item) = this.buf_left.pop_front() {
52            // There was already a value in the buffer. Return that value
53            return Poll::Ready(Some(item));
54        }
55        if this.buf_right.remaining() == 0 {
56            // There is a value available for the other stream. Wake that stream if possible
57            // and return pending since we can't store multiple values for a stream
58            if let Some(waker) = this.waker_right {
59                waker.wake_by_ref();
60            }
61            return Poll::Pending;
62        }
63        match this.stream.poll_next(cx) {
64            Poll::Ready(Some(item)) => {
65                match (this.predicate)(item) {
66                    Either::Left(left_item) => Poll::Ready(Some(left_item)),
67                    Either::Right(right_item) => {
68                        // This value is not what we wanted. Store it and notify other partition
69                        // task if it exists
70                        let _ = this.buf_right.push_back(right_item);
71                        if let Some(waker) = this.waker_right {
72                            waker.wake_by_ref();
73                        }
74                        Poll::Pending
75                    }
76                }
77            }
78            Poll::Ready(None) => {
79                // If the underlying stream is finished, the `right` stream also must be
80                // finished, so wake it in case nothing else polls it
81                if let Some(waker) = this.waker_right {
82                    waker.wake_by_ref();
83                }
84                Poll::Ready(None)
85            }
86            Poll::Pending => Poll::Pending,
87        }
88    }
89
90    fn poll_next_right(
91        self: std::pin::Pin<&mut Self>,
92        cx: &mut std::task::Context<'_>,
93    ) -> std::task::Poll<Option<R>> {
94        let this = self.project();
95        // I think there should only ever be one waker calling the function
96        if this.waker_right.is_none() {
97            *this.waker_right = Some(cx.waker().clone());
98        }
99        if let Some(item) = this.buf_right.pop_front() {
100            // There was already a value in the buffer. Return that value
101            return Poll::Ready(Some(item));
102        }
103        if this.buf_left.remaining() == 0 {
104            // There is a value available for the other stream. Wake that stream if possible
105            // and return pending since we can't store multiple values for a stream
106            if let Some(waker) = this.waker_left {
107                waker.wake_by_ref();
108            }
109            return Poll::Pending;
110        }
111        match this.stream.poll_next(cx) {
112            Poll::Ready(Some(item)) => {
113                match (this.predicate)(item) {
114                    Either::Left(left_item) => {
115                        // This value is not what we wanted. Store it and notify other partition
116                        // task if it exists
117                        let _ = this.buf_left.push_back(left_item);
118                        if let Some(waker) = this.waker_left {
119                            waker.wake_by_ref();
120                        }
121                        Poll::Pending
122                    }
123                    Either::Right(right_item) => Poll::Ready(Some(right_item)),
124                }
125            }
126            Poll::Ready(None) => {
127                // If the underlying stream is finished, the `left` stream also must be
128                // finished, so wake it in case nothing else polls it
129                if let Some(waker) = this.waker_left {
130                    waker.wake_by_ref();
131                }
132                Poll::Ready(None)
133            }
134            Poll::Pending => Poll::Pending,
135        }
136    }
137}
138
139/// A struct that implements `Stream` which returns the inner values where
140/// the predicate returns `Either::Left(..)` when using `split_by_map`
141pub struct LeftSplitByMapBuffered<I, L, R, S, P, const N: usize> {
142    stream: Arc<Mutex<SplitByMapBuffered<I, L, R, S, P, N>>>,
143}
144
145impl<I, L, R, S, P, const N: usize> LeftSplitByMapBuffered<I, L, R, S, P, N> {
146    pub(crate) fn new(stream: Arc<Mutex<SplitByMapBuffered<I, L, R, S, P, N>>>) -> Self {
147        Self { stream }
148    }
149}
150
151impl<I, L, R, S, P, const N: usize> Stream for LeftSplitByMapBuffered<I, L, R, S, P, N>
152where
153    S: Stream<Item = I> + Unpin,
154    P: Fn(I) -> Either<L, R>,
155{
156    type Item = L;
157    fn poll_next(
158        self: std::pin::Pin<&mut Self>,
159        cx: &mut std::task::Context<'_>,
160    ) -> std::task::Poll<Option<Self::Item>> {
161        let response = if let Ok(mut guard) = self.stream.try_lock() {
162            SplitByMapBuffered::poll_next_left(Pin::new(&mut guard), cx)
163        } else {
164            cx.waker().wake_by_ref();
165            Poll::Pending
166        };
167        response
168    }
169}
170
171/// A struct that implements `Stream` which returns the inner values where
172/// the predicate returns `Either::Right(..)` when using `split_by_map`
173pub struct RightSplitByMapBuffered<I, L, R, S, P, const N: usize> {
174    stream: Arc<Mutex<SplitByMapBuffered<I, L, R, S, P, N>>>,
175}
176
177impl<I, L, R, S, P, const N: usize> RightSplitByMapBuffered<I, L, R, S, P, N> {
178    pub(crate) fn new(stream: Arc<Mutex<SplitByMapBuffered<I, L, R, S, P, N>>>) -> Self {
179        Self { stream }
180    }
181}
182
183impl<I, L, R, S, P, const N: usize> Stream for RightSplitByMapBuffered<I, L, R, S, P, N>
184where
185    S: Stream<Item = I> + Unpin,
186    P: Fn(I) -> Either<L, R>,
187{
188    type Item = R;
189    fn poll_next(
190        self: std::pin::Pin<&mut Self>,
191        cx: &mut std::task::Context<'_>,
192    ) -> std::task::Poll<Option<Self::Item>> {
193        let response = if let Ok(mut guard) = self.stream.try_lock() {
194            SplitByMapBuffered::poll_next_right(Pin::new(&mut guard), cx)
195        } else {
196            cx.waker().wake_by_ref();
197            Poll::Pending
198        };
199        response
200    }
201}