1use std::{
2 marker::PhantomData,
3 pin::Pin,
4 sync::{Arc, Mutex},
5 task::{Poll, Waker},
6};
7
8use futures::{future::Either, Stream};
9use pin_project::pin_project;
10
11use crate::ring_buf::RingBuf;
12
13#[pin_project]
14pub(crate) struct SplitByMapBuffered<I, L, R, S, P, const N: usize> {
15 buf_left: RingBuf<L, N>,
16 buf_right: RingBuf<R, N>,
17 waker_left: Option<Waker>,
18 waker_right: Option<Waker>,
19 #[pin]
20 stream: S,
21 predicate: P,
22 item: PhantomData<I>,
23}
24
25impl<I, L, R, S, P, const N: usize> SplitByMapBuffered<I, L, R, S, P, N>
26where
27 S: Stream<Item = I>,
28 P: Fn(I) -> Either<L, R>,
29{
30 pub(crate) fn new(stream: S, predicate: P) -> Arc<Mutex<Self>> {
31 Arc::new(Mutex::new(Self {
32 buf_right: RingBuf::new(),
33 buf_left: RingBuf::new(),
34 waker_right: None,
35 waker_left: None,
36 stream,
37 predicate,
38 item: PhantomData,
39 }))
40 }
41
42 fn poll_next_left(
43 self: std::pin::Pin<&mut Self>,
44 cx: &mut std::task::Context<'_>,
45 ) -> std::task::Poll<Option<L>> {
46 let this = self.project();
47 if this.waker_left.is_none() {
49 *this.waker_left = Some(cx.waker().clone());
50 }
51 if let Some(item) = this.buf_left.pop_front() {
52 return Poll::Ready(Some(item));
54 }
55 if this.buf_right.remaining() == 0 {
56 if let Some(waker) = this.waker_right {
59 waker.wake_by_ref();
60 }
61 return Poll::Pending;
62 }
63 match this.stream.poll_next(cx) {
64 Poll::Ready(Some(item)) => {
65 match (this.predicate)(item) {
66 Either::Left(left_item) => Poll::Ready(Some(left_item)),
67 Either::Right(right_item) => {
68 let _ = this.buf_right.push_back(right_item);
71 if let Some(waker) = this.waker_right {
72 waker.wake_by_ref();
73 }
74 Poll::Pending
75 }
76 }
77 }
78 Poll::Ready(None) => {
79 if let Some(waker) = this.waker_right {
82 waker.wake_by_ref();
83 }
84 Poll::Ready(None)
85 }
86 Poll::Pending => Poll::Pending,
87 }
88 }
89
90 fn poll_next_right(
91 self: std::pin::Pin<&mut Self>,
92 cx: &mut std::task::Context<'_>,
93 ) -> std::task::Poll<Option<R>> {
94 let this = self.project();
95 if this.waker_right.is_none() {
97 *this.waker_right = Some(cx.waker().clone());
98 }
99 if let Some(item) = this.buf_right.pop_front() {
100 return Poll::Ready(Some(item));
102 }
103 if this.buf_left.remaining() == 0 {
104 if let Some(waker) = this.waker_left {
107 waker.wake_by_ref();
108 }
109 return Poll::Pending;
110 }
111 match this.stream.poll_next(cx) {
112 Poll::Ready(Some(item)) => {
113 match (this.predicate)(item) {
114 Either::Left(left_item) => {
115 let _ = this.buf_left.push_back(left_item);
118 if let Some(waker) = this.waker_left {
119 waker.wake_by_ref();
120 }
121 Poll::Pending
122 }
123 Either::Right(right_item) => Poll::Ready(Some(right_item)),
124 }
125 }
126 Poll::Ready(None) => {
127 if let Some(waker) = this.waker_left {
130 waker.wake_by_ref();
131 }
132 Poll::Ready(None)
133 }
134 Poll::Pending => Poll::Pending,
135 }
136 }
137}
138
139pub struct LeftSplitByMapBuffered<I, L, R, S, P, const N: usize> {
142 stream: Arc<Mutex<SplitByMapBuffered<I, L, R, S, P, N>>>,
143}
144
145impl<I, L, R, S, P, const N: usize> LeftSplitByMapBuffered<I, L, R, S, P, N> {
146 pub(crate) fn new(stream: Arc<Mutex<SplitByMapBuffered<I, L, R, S, P, N>>>) -> Self {
147 Self { stream }
148 }
149}
150
151impl<I, L, R, S, P, const N: usize> Stream for LeftSplitByMapBuffered<I, L, R, S, P, N>
152where
153 S: Stream<Item = I> + Unpin,
154 P: Fn(I) -> Either<L, R>,
155{
156 type Item = L;
157 fn poll_next(
158 self: std::pin::Pin<&mut Self>,
159 cx: &mut std::task::Context<'_>,
160 ) -> std::task::Poll<Option<Self::Item>> {
161 let response = if let Ok(mut guard) = self.stream.try_lock() {
162 SplitByMapBuffered::poll_next_left(Pin::new(&mut guard), cx)
163 } else {
164 cx.waker().wake_by_ref();
165 Poll::Pending
166 };
167 response
168 }
169}
170
171pub struct RightSplitByMapBuffered<I, L, R, S, P, const N: usize> {
174 stream: Arc<Mutex<SplitByMapBuffered<I, L, R, S, P, N>>>,
175}
176
177impl<I, L, R, S, P, const N: usize> RightSplitByMapBuffered<I, L, R, S, P, N> {
178 pub(crate) fn new(stream: Arc<Mutex<SplitByMapBuffered<I, L, R, S, P, N>>>) -> Self {
179 Self { stream }
180 }
181}
182
183impl<I, L, R, S, P, const N: usize> Stream for RightSplitByMapBuffered<I, L, R, S, P, N>
184where
185 S: Stream<Item = I> + Unpin,
186 P: Fn(I) -> Either<L, R>,
187{
188 type Item = R;
189 fn poll_next(
190 self: std::pin::Pin<&mut Self>,
191 cx: &mut std::task::Context<'_>,
192 ) -> std::task::Poll<Option<Self::Item>> {
193 let response = if let Ok(mut guard) = self.stream.try_lock() {
194 SplitByMapBuffered::poll_next_right(Pin::new(&mut guard), cx)
195 } else {
196 cx.waker().wake_by_ref();
197 Poll::Pending
198 };
199 response
200 }
201}