rustrade_integration/socket/
on_connect_err.rs1use futures::{Sink, Stream};
2use pin_project::pin_project;
3use std::{
4 pin::Pin,
5 task::{Context, Poll, ready},
6};
7
8pub trait ConnectErrorHandler<Err> {
10 fn handle(&mut self, error: &ConnectError<Err>) -> ConnectErrorAction;
12}
13
14impl<Err, F> ConnectErrorHandler<Err> for F
15where
16 F: FnMut(&ConnectError<Err>) -> ConnectErrorAction,
17{
18 #[inline]
19 fn handle(&mut self, error: &ConnectError<Err>) -> ConnectErrorAction {
20 self(error)
21 }
22}
23
/// Error describing a failed connection attempt, tagged with a reconnection attempt number.
///
/// `Eq` and `Hash` are derived alongside `PartialEq`, so values can be stored in sets or used
/// as map keys whenever `ErrConnect` supports it.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct ConnectError<ErrConnect> {
    /// Reconnection attempt number this error is associated with.
    pub reconnection_attempt: u32,
    /// What went wrong during the attempt.
    pub kind: ConnectErrorKind<ErrConnect>,
}

/// Reason a connection attempt failed.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ConnectErrorKind<ErrConnect> {
    /// The attempt failed with the wrapped connection error.
    Connect(ErrConnect),
    /// The attempt timed out; no underlying error value is carried.
    Timeout,
}
39
/// Decision returned by a [`ConnectErrorHandler`] after it has seen a [`ConnectError`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq`, so the action can be used as a map key
/// or collected into a set.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ConnectErrorAction {
    /// Disregard the error and keep waiting for the next connection.
    Reconnect,
    /// Give up; no further connections should be produced.
    Terminate,
}
48
49#[derive(Debug)]
56#[pin_project]
57pub struct OnConnectErr<S, ErrHandler> {
58 #[pin]
59 socket: S,
60 on_err: ErrHandler,
61}
62
63impl<S, ErrHandler> OnConnectErr<S, ErrHandler> {
64 pub fn new(socket: S, on_err: ErrHandler) -> Self {
65 Self { socket, on_err }
66 }
67}
68
69impl<S, Socket, ErrConnect, ErrHandler> Stream for OnConnectErr<S, ErrHandler>
70where
71 S: Stream<Item = Result<Socket, ConnectError<ErrConnect>>>,
72 ErrHandler: ConnectErrorHandler<ErrConnect>,
73{
74 type Item = Socket;
75
76 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
77 let mut this = self.project();
78
79 loop {
80 let next_ready = ready!(this.socket.as_mut().poll_next(cx));
81
82 let Some(result) = next_ready else {
83 return Poll::Ready(None);
84 };
85
86 match result {
87 Ok(socket) => {
88 return Poll::Ready(Some(socket));
89 }
90 Err(error) => {
91 match this.on_err.handle(&error) {
92 ConnectErrorAction::Reconnect => {
93 }
95 ConnectErrorAction::Terminate => {
96 return Poll::Ready(None);
97 }
98 }
99 }
100 }
101 }
102 }
103}
104
105impl<S, ErrHandler, Item> Sink<Item> for OnConnectErr<S, ErrHandler>
106where
107 S: Sink<Item>,
108{
109 type Error = S::Error;
110
111 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112 self.project().socket.poll_ready(cx)
113 }
114
115 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
116 self.project().socket.start_send(item)
117 }
118
119 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 self.project().socket.poll_flush(cx)
121 }
122
123 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124 self.project().socket.poll_close(cx)
125 }
126}
127
#[cfg(test)]
// Tests unwrap channel sends directly: a failed send means the test setup itself is broken.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::socket::ReconnectingSocket;
    use futures::StreamExt;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready_eq};

    type TestSocket = i32;
    type TestError = &'static str;
    type TestResult = Result<TestSocket, ConnectError<TestError>>;

    /// Polling context backed by a waker that does nothing; the tests poll by hand.
    fn noop_context() -> Context<'static> {
        Context::from_waker(futures::task::noop_waker_ref())
    }

    /// Unbounded channel whose receiving half is exposed as a stream of connect results.
    fn connect_results() -> (
        mpsc::UnboundedSender<TestResult>,
        UnboundedReceiverStream<TestResult>,
    ) {
        let (tx, rx) = mpsc::unbounded_channel();
        (tx, UnboundedReceiverStream::new(rx))
    }

    /// Builds the `Err` item for the given attempt number and error kind.
    fn connect_err(reconnection_attempt: u32, kind: ConnectErrorKind<TestError>) -> TestResult {
        Err(ConnectError {
            reconnection_attempt,
            kind,
        })
    }

    #[tokio::test]
    async fn test_on_connect_err_passes_through_success() {
        let mut cx = noop_context();
        let (tx, rx) = connect_results();

        let mut stream =
            rx.on_connect_err(|_error: &ConnectError<TestError>| ConnectErrorAction::Reconnect);

        // Nothing has been sent yet.
        assert_pending!(stream.poll_next_unpin(&mut cx));

        // Successful connections come out exactly as they went in.
        for socket in [1, 2] {
            tx.send(Ok(socket)).unwrap();
            assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(socket));
        }

        // Closing the source ends the stream.
        drop(tx);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_on_connect_err_reconnect_action() {
        let mut cx = noop_context();
        let (tx, rx) = connect_results();

        let mut stream =
            rx.on_connect_err(|_error: &ConnectError<TestError>| ConnectErrorAction::Reconnect);

        // Each error is swallowed by the handler, leaving the stream pending.
        let swallowed = [
            connect_err(1, ConnectErrorKind::Connect("network error")),
            connect_err(2, ConnectErrorKind::Timeout),
        ];
        for error in swallowed {
            tx.send(error).unwrap();
            assert_pending!(stream.poll_next_unpin(&mut cx));
        }

        // A later success still gets through.
        tx.send(Ok(42)).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(42));

        drop(tx);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_on_connect_err_terminate_action() {
        let mut cx = noop_context();
        let (tx, rx) = connect_results();

        // Tolerate errors until the third reconnection attempt, then give up.
        let mut stream = rx.on_connect_err(|error: &ConnectError<TestError>| {
            if error.reconnection_attempt >= 3 {
                ConnectErrorAction::Terminate
            } else {
                ConnectErrorAction::Reconnect
            }
        });

        tx.send(Ok(1)).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));

        // Below the threshold: error is skipped and the stream keeps waiting.
        tx.send(connect_err(1, ConnectErrorKind::Connect("error")))
            .unwrap();
        assert_pending!(stream.poll_next_unpin(&mut cx));

        // At the threshold: the handler terminates the stream.
        tx.send(connect_err(3, ConnectErrorKind::Connect("error")))
            .unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }
}