Skip to main content

rustrade_integration/stream/ext/
forward_by.rs

1use futures::{Sink, Stream, ready};
2use pin_project::pin_project;
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8/// Stream adapter that forwards matching items and yields non-matching items.
9///
10/// Uses predicate to split items: Left items are forwarded, Right items are yielded.
11#[derive(Debug)]
12#[pin_project]
13pub struct ForwardBy<S, FnPredicate, FnForward> {
14    #[pin]
15    socket: S,
16    predicate: FnPredicate,
17    forward: FnForward,
18}
19
20impl<S, FnPredicate, FnForward> ForwardBy<S, FnPredicate, FnForward> {
21    pub fn new(socket: S, predicate: FnPredicate, forward: FnForward) -> Self {
22        Self {
23            socket,
24            predicate,
25            forward,
26        }
27    }
28}
29
30impl<S, A, B, FnPredicate, FnForward, FwdErr> Stream for ForwardBy<S, FnPredicate, FnForward>
31where
32    S: Stream,
33    FnPredicate: Fn(S::Item) -> futures::future::Either<A, B>,
34    FnForward: FnMut(A) -> Result<(), FwdErr>,
35{
36    type Item = B;
37
38    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
39        let mut this = self.project();
40
41        loop {
42            let next_ready = ready!(this.socket.as_mut().poll_next(cx));
43
44            let Some(item) = next_ready else {
45                return Poll::Ready(None);
46            };
47
48            match (this.predicate)(item) {
49                futures::future::Either::Left(left) => {
50                    if (this.forward)(left).is_err() {
51                        return Poll::Ready(None);
52                    } else {
53                        // Initiate next poll_next immediately
54                    }
55                }
56                futures::future::Either::Right(right) => return Poll::Ready(Some(right)),
57            }
58        }
59    }
60}
61
62impl<S, FnPredicate, FnForward, Item> Sink<Item> for ForwardBy<S, FnPredicate, FnForward>
63where
64    S: Sink<Item>,
65{
66    type Error = S::Error;
67
68    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
69        self.project().socket.poll_ready(cx)
70    }
71
72    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
73        self.project().socket.start_send(item)
74    }
75
76    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
77        self.project().socket.poll_flush(cx)
78    }
79
80    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
81        self.project().socket.poll_close(cx)
82    }
83}
84
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Test code: panics on bad input are acceptable
mod tests {
    use super::*;
    use crate::stream::ext::BarterStreamExt;
    use futures::{StreamExt, future::Either};
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready_eq};

    /// Predicate shared by every test: even numbers go `Left` (forwarded), odd numbers go
    /// `Right` (yielded).
    fn even_left_odd_right(item: i32) -> Either<i32, i32> {
        match item % 2 {
            0 => Either::Left(item),
            _ => Either::Right(item),
        }
    }

    /// Builds an unbounded `i32` channel and wraps its receiving half as a `Stream`.
    fn input() -> (mpsc::UnboundedSender<i32>, UnboundedReceiverStream<i32>) {
        let (sender, receiver) = mpsc::unbounded_channel::<i32>();
        (sender, UnboundedReceiverStream::new(receiver))
    }

    #[tokio::test]
    async fn test_forward_by() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let (tx, rx) = input();
        let (forward_tx, mut forward_rx) = mpsc::unbounded_channel::<i32>();

        let mut stream = rx.forward_by(even_left_odd_right, move |item| {
            forward_tx.send(item).map_err(|_| ())
        });

        // Nothing has been sent yet.
        assert_pending!(stream.poll_next_unpin(&mut cx));

        // Odd value: yielded, nothing forwarded.
        tx.send(1).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));
        assert!(forward_rx.try_recv().is_err());

        // Even value: forwarded, so the stream itself stays pending.
        tx.send(2).unwrap();
        assert_pending!(stream.poll_next_unpin(&mut cx));
        assert_eq!(forward_rx.try_recv().unwrap(), 2);

        tx.send(3).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(3));

        // A forwarded value followed by a yielded one are both handled in a single poll.
        for value in [4, 5] {
            tx.send(value).unwrap();
        }
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(5));
        assert_eq!(forward_rx.try_recv().unwrap(), 4);

        // Closing the input ends the stream.
        drop(tx);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_forward_by_with_custom_error_type() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let (tx, rx) = input();

        let mut stream = rx.forward_by(even_left_odd_right, |item: i32| -> Result<(), String> {
            match item {
                4 => Err(format!("rejected {item}")),
                _ => Ok(()),
            }
        });

        tx.send(2).unwrap();
        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(1).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));

        // The forward closure rejects 4 with a `String` error, which ends the stream.
        tx.send(4).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_forward_by_terminates_on_forward_error() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let (tx, rx) = input();

        let mut stream = rx.forward_by(even_left_odd_right, |item: i32| match item {
            4 => Err(()),
            _ => Ok(()),
        });

        tx.send(1).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));

        tx.send(2).unwrap();
        assert_pending!(stream.poll_next_unpin(&mut cx));

        // Forwarding 4 fails, so the stream reports end-of-stream.
        tx.send(4).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }
}