rustrade_integration/stream/ext/
forward_by.rs1use futures::{Sink, Stream, ready};
2use pin_project::pin_project;
3use std::{
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8#[derive(Debug)]
12#[pin_project]
13pub struct ForwardBy<S, FnPredicate, FnForward> {
14 #[pin]
15 socket: S,
16 predicate: FnPredicate,
17 forward: FnForward,
18}
19
20impl<S, FnPredicate, FnForward> ForwardBy<S, FnPredicate, FnForward> {
21 pub fn new(socket: S, predicate: FnPredicate, forward: FnForward) -> Self {
22 Self {
23 socket,
24 predicate,
25 forward,
26 }
27 }
28}
29
30impl<S, A, B, FnPredicate, FnForward, FwdErr> Stream for ForwardBy<S, FnPredicate, FnForward>
31where
32 S: Stream,
33 FnPredicate: Fn(S::Item) -> futures::future::Either<A, B>,
34 FnForward: FnMut(A) -> Result<(), FwdErr>,
35{
36 type Item = B;
37
38 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
39 let mut this = self.project();
40
41 loop {
42 let next_ready = ready!(this.socket.as_mut().poll_next(cx));
43
44 let Some(item) = next_ready else {
45 return Poll::Ready(None);
46 };
47
48 match (this.predicate)(item) {
49 futures::future::Either::Left(left) => {
50 if (this.forward)(left).is_err() {
51 return Poll::Ready(None);
52 } else {
53 }
55 }
56 futures::future::Either::Right(right) => return Poll::Ready(Some(right)),
57 }
58 }
59 }
60}
61
62impl<S, FnPredicate, FnForward, Item> Sink<Item> for ForwardBy<S, FnPredicate, FnForward>
63where
64 S: Sink<Item>,
65{
66 type Error = S::Error;
67
68 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
69 self.project().socket.poll_ready(cx)
70 }
71
72 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
73 self.project().socket.start_send(item)
74 }
75
76 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
77 self.project().socket.poll_flush(cx)
78 }
79
80 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
81 self.project().socket.poll_close(cx)
82 }
83}
84
#[cfg(test)]
// Tests should fail loudly on unexpected channel errors, so `unwrap` is intentional here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::stream::ext::BarterStreamExt;
    use futures::{StreamExt, future::Either};
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready_eq};

    /// Sends even numbers to the forward side (`Left`) and odd numbers to the output (`Right`).
    fn route_by_parity(item: i32) -> Either<i32, i32> {
        match item % 2 {
            0 => Either::Left(item),
            _ => Either::Right(item),
        }
    }

    #[tokio::test]
    async fn test_forward_by() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        let (input_tx, input_rx) = mpsc::unbounded_channel::<i32>();
        let (forward_tx, mut forward_rx) = mpsc::unbounded_channel::<i32>();

        let mut stream = UnboundedReceiverStream::new(input_rx).forward_by(
            route_by_parity,
            move |item| forward_tx.send(item).map_err(|_| ()),
        );

        // Nothing has been sent yet.
        assert_pending!(stream.poll_next_unpin(&mut cx));

        // Odd values are yielded and never forwarded.
        input_tx.send(1).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));
        assert!(forward_rx.try_recv().is_err());

        // Even values are forwarded and never yielded.
        input_tx.send(2).unwrap();
        assert_pending!(stream.poll_next_unpin(&mut cx));
        assert_eq!(forward_rx.try_recv().unwrap(), 2);

        input_tx.send(3).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(3));

        // A forwarded value is skipped over within the same poll.
        input_tx.send(4).unwrap();
        input_tx.send(5).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(5));
        assert_eq!(forward_rx.try_recv().unwrap(), 4);

        // Closing the input ends the stream.
        drop(input_tx);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_forward_by_with_custom_error_type() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        let (input_tx, input_rx) = mpsc::unbounded_channel::<i32>();

        let mut stream = UnboundedReceiverStream::new(input_rx).forward_by(
            route_by_parity,
            |item: i32| -> Result<(), String> {
                match item {
                    4 => Err(format!("rejected {item}")),
                    _ => Ok(()),
                }
            },
        );

        input_tx.send(2).unwrap();
        assert_pending!(stream.poll_next_unpin(&mut cx));

        input_tx.send(1).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));

        // A forward error of any type ends the stream.
        input_tx.send(4).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_forward_by_terminates_on_forward_error() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        let (input_tx, input_rx) = mpsc::unbounded_channel::<i32>();

        let mut stream = UnboundedReceiverStream::new(input_rx).forward_by(
            route_by_parity,
            |item: i32| if item == 4 { Err(()) } else { Ok(()) },
        );

        input_tx.send(1).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));

        input_tx.send(2).unwrap();
        assert_pending!(stream.poll_next_unpin(&mut cx));

        input_tx.send(4).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }
}