// rustrade_integration/socket/on_stream_err.rs
use futures::{Sink, Stream};
use pin_project::pin_project;
use std::{
    pin::Pin,
    task::{Context, Poll, ready},
};

8pub trait StreamErrorHandler<Err> {
10 fn handle(&mut self, error: &Err) -> StreamErrorAction;
12}
13
14impl<Err, F> StreamErrorHandler<Err> for F
15where
16 F: FnMut(&Err) -> StreamErrorAction,
17{
18 #[inline]
19 fn handle(&mut self, error: &Err) -> StreamErrorAction {
20 self(error)
21 }
22}
23
/// Decision returned by a [`StreamErrorHandler`] for one error yielded by the wrapped stream.
///
/// The enum is fieldless, so it derives `Eq` and `Hash` in addition to `PartialEq`. Callers
/// can then use it as a map/set key or pass it to `Eq`-bounded APIs.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StreamErrorAction {
    /// Yield the error to the consumer as `Some(Err(error))`; later items keep flowing.
    Continue,
    /// Swallow the error and end the stream by yielding `None`
    /// (presumably so an outer layer can re-establish the connection).
    Reconnect,
}

33#[derive(Debug)]
39#[pin_project]
40pub struct OnStreamErr<S, ErrHandler> {
41 #[pin]
42 socket: S,
43 on_err: ErrHandler,
44}
45
46impl<S, ErrHandler> OnStreamErr<S, ErrHandler> {
47 pub fn new(socket: S, on_err: ErrHandler) -> Self {
48 Self { socket, on_err }
49 }
50}
51
52impl<S, StOk, StErr, ErrHandler> Stream for OnStreamErr<S, ErrHandler>
53where
54 S: Stream<Item = Result<StOk, StErr>>,
55 ErrHandler: StreamErrorHandler<StErr>,
56{
57 type Item = Result<StOk, StErr>;
58
59 #[inline]
60 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
61 let mut this = self.project();
62
63 let next_ready = ready!(this.socket.as_mut().poll_next(cx));
64
65 let Some(result) = next_ready else {
66 return Poll::Ready(None);
67 };
68
69 match result {
70 Ok(item) => Poll::Ready(Some(Ok(item))),
71 Err(error) => match (this.on_err).handle(&error) {
72 StreamErrorAction::Continue => Poll::Ready(Some(Err(error))),
73 StreamErrorAction::Reconnect => Poll::Ready(None),
74 },
75 }
76 }
77}
78
79impl<St, ErrHandler, Item> Sink<Item> for OnStreamErr<St, ErrHandler>
80where
81 St: Sink<Item>,
82{
83 type Error = St::Error;
84
85 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 self.project().socket.poll_ready(cx)
87 }
88
89 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
90 self.project().socket.start_send(item)
91 }
92
93 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
94 self.project().socket.poll_flush(cx)
95 }
96
97 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98 self.project().socket.poll_close(cx)
99 }
100}
101
#[cfg(test)]
// `unwrap` is the idiomatic way to fail a test on an unexpected `Err`.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::cell::Cell;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready};

    type TestError = &'static str;

    /// In-memory `Sink` that records what it receives. It is used to check that
    /// `OnStreamErr` forwards every `Sink` call to the wrapped socket.
    #[derive(Debug, Default)]
    struct RecordingSink {
        items: Vec<i32>,
        flushes: usize,
        closed: bool,
    }

    impl Sink<i32> for RecordingSink {
        type Error = TestError;

        fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn start_send(self: Pin<&mut Self>, item: i32) -> Result<(), Self::Error> {
            self.get_mut().items.push(item);
            Ok(())
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.get_mut().flushes += 1;
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.get_mut().closed = true;
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn test_on_stream_err_passes_through_ok() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<Result<i32, TestError>>();
        let rx = UnboundedReceiverStream::new(rx);

        let mut stream = OnStreamErr::new(rx, |_error: &TestError| StreamErrorAction::Continue);

        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(Ok(1)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(1)));

        tx.send(Ok(2)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(2)));

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    #[tokio::test]
    async fn test_on_stream_err_continue_action() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<Result<i32, TestError>>();
        let rx = UnboundedReceiverStream::new(rx);

        let mut stream = OnStreamErr::new(rx, |_error: &TestError| StreamErrorAction::Continue);

        tx.send(Ok(1)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(1)));

        tx.send(Err("error1")).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err("error1"))
        );

        tx.send(Ok(2)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(2)));

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    #[tokio::test]
    async fn test_on_stream_err_reconnect_action() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<Result<i32, TestError>>();
        let rx = UnboundedReceiverStream::new(rx);

        let mut stream = OnStreamErr::new(rx, |error: &TestError| {
            if *error == "fatal" {
                StreamErrorAction::Reconnect
            } else {
                StreamErrorAction::Continue
            }
        });

        tx.send(Ok(1)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(1)));

        tx.send(Err("non-fatal")).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err("non-fatal"))
        );

        tx.send(Err("fatal")).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    /// The handler must be consulted only for `Err` items, exactly once per error.
    #[tokio::test]
    async fn test_on_stream_err_handler_only_sees_errors() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<Result<i32, TestError>>();
        let rx = UnboundedReceiverStream::new(rx);

        let calls = Cell::new(0_usize);
        let mut stream = OnStreamErr::new(rx, |_error: &TestError| {
            calls.set(calls.get() + 1);
            StreamErrorAction::Continue
        });

        tx.send(Ok(1)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(1)));
        assert_eq!(calls.get(), 0);

        tx.send(Err("error1")).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err("error1"))
        );
        assert_eq!(calls.get(), 1);

        // Pending polls must not invoke the handler either.
        assert_pending!(stream.poll_next_unpin(&mut cx));
        assert_eq!(calls.get(), 1);
    }

    /// `poll_ready`, `start_send`, `poll_flush` and `poll_close` all reach the wrapped socket.
    #[test]
    fn test_on_stream_err_forwards_sink_calls() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let mut sink = OnStreamErr::new(RecordingSink::default(), |_error: &TestError| {
            StreamErrorAction::Continue
        });

        assert_ready!(Sink::<i32>::poll_ready(Pin::new(&mut sink), &mut cx)).unwrap();
        Sink::<i32>::start_send(Pin::new(&mut sink), 1).unwrap();
        Sink::<i32>::start_send(Pin::new(&mut sink), 2).unwrap();
        assert_ready!(Sink::<i32>::poll_flush(Pin::new(&mut sink), &mut cx)).unwrap();
        assert_ready!(Sink::<i32>::poll_close(Pin::new(&mut sink), &mut cx)).unwrap();

        assert_eq!(sink.socket.items, vec![1, 2]);
        assert_eq!(sink.socket.flushes, 1);
        assert!(sink.socket.closed);
    }
}