1use std::error::Error;
2use std::future::Future;
3use std::iter::once;
4use std::marker::PhantomData;
5use std::ops::{Deref, DerefMut};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use futures::{ready, Sink, Stream};
11use log::{debug, error, info};
12
13use crate::config::ReconnectOptions;
14
/// An IO resource that [`ReconnectStream`] can (re)create on demand and
/// inspect for disconnects.
///
/// `C` is the constructor argument handed to [`UnderlyingStream::establish`]
/// for every connection attempt, `I` is the item type read from the resource
/// and `E` is the error type produced when connecting or writing.
pub trait UnderlyingStream<C, I, E>: Sized + Unpin
where
    C: Clone + Send + Unpin,
    E: Error,
{
    /// Builds a fresh instance of the IO resource from `ctor_arg`.
    ///
    /// Invoked for every initial connection attempt and for every reconnect
    /// attempt.
    fn establish(ctor_arg: C) -> Pin<Box<dyn Future<Output = Result<Self, E>> + Send>>;

    /// Returns `true` when the write-side error `err` means the connection is
    /// gone and a reconnect should be started.
    fn is_write_disconnect_error(&self, err: &E) -> bool;

    /// Returns `true` when the read `item` signals that the connection is gone.
    ///
    /// Items for which this returns `true` are not yielded to the reader; a
    /// reconnect is started instead. The default never flags any item.
    fn is_read_disconnect_error(&self, item: &I) -> bool {
        // The default implementation ignores the item entirely.
        let _ = item;
        false
    }

    /// The error reported by the sink side once every reconnect retry has
    /// been used up.
    fn exhaust_err() -> E;
}
43
/// Bookkeeping for the reconnect attempts made during a single disconnect
/// episode.
struct AttemptsTracker {
    /// Number of reconnect attempts scheduled so far in this episode.
    attempt_num: usize,
    /// Remaining back-off delays: each `next()` yields the wait before the
    /// following attempt, and `None` means no retries are left.
    retries_remaining: Box<dyn Iterator<Item = Duration> + Send>,
}
48
49struct ReconnectStatus<T, C, I, E> {
50 attempts_tracker: AttemptsTracker,
51 reconnect_attempt: Pin<Box<dyn Future<Output = Result<T, E>> + Send>>,
52 _marker_1: PhantomData<C>,
53 _marker_2: PhantomData<I>,
54 _marker_3: PhantomData<E>,
55}
56
57impl<T, C, I, E> ReconnectStatus<T, C, I, E>
58where
59 T: UnderlyingStream<C, I, E>,
60 C: Clone + Send + Unpin + 'static,
61 E: Error + Unpin,
62{
63 pub fn new(options: &ReconnectOptions) -> Self {
64 ReconnectStatus {
65 attempts_tracker: AttemptsTracker {
66 attempt_num: 0,
67 retries_remaining: (options.retries_to_attempt_fn())(),
68 },
69 reconnect_attempt: Box::pin(async { unreachable!("Not going to happen") }),
70 _marker_1: PhantomData,
71 _marker_2: PhantomData,
72 _marker_3: PhantomData,
73 }
74 }
75}
76
77pub struct ReconnectStream<T, C, I, E> {
81 status: Status<T, C, I, E>,
82 underlying_io: T,
83 options: ReconnectOptions,
84 ctor_arg: C,
85 _marker: PhantomData<I>,
86}
87
88enum Status<T, C, I, E> {
89 Connected,
90 Disconnected(ReconnectStatus<T, C, I, E>),
91 FailedAndExhausted, }
93
94impl<T, C, I, E> Deref for ReconnectStream<T, C, I, E> {
95 type Target = T;
96
97 fn deref(&self) -> &Self::Target {
98 &self.underlying_io
99 }
100}
101
102impl<T, C, I, E> DerefMut for ReconnectStream<T, C, I, E> {
103 fn deref_mut(&mut self) -> &mut Self::Target {
104 &mut self.underlying_io
105 }
106}
107
108impl<T, C, I, E> ReconnectStream<T, C, I, E>
109where
110 T: UnderlyingStream<C, I, E>,
111 C: Clone + Send + Unpin + 'static,
112 I: Unpin,
113 E: Error + Unpin,
114{
115 pub async fn connect(ctor_arg: C) -> Result<Self, E> {
118 let options = ReconnectOptions::new();
119 Self::connect_with_options(ctor_arg, options).await
120 }
121
122 pub async fn connect_with_options(ctor_arg: C, options: ReconnectOptions) -> Result<Self, E> {
123 let tries = (**options.retries_to_attempt_fn())()
124 .map(Some)
125 .chain(once(None));
126 let mut result = None;
127 for (counter, maybe_delay) in tries.enumerate() {
128 match T::establish(ctor_arg.clone()).await {
129 Ok(inner) => {
130 debug!("Initial connection succeeded.");
131 (options.on_connect_callback())();
132 result = Some(Ok(inner));
133 break;
134 }
135 Err(e) => {
136 error!("Connection failed due to: {:?}.", e);
137 (options.on_connect_fail_callback())();
138
139 if options.exit_if_first_connect_fails() {
140 error!("Bailing after initial connection failure.");
141 return Err(e);
142 }
143
144 result = Some(Err(e));
145
146 if let Some(delay) = maybe_delay {
147 debug!(
148 "Will re-perform initial connect attempt #{} in {:?}.",
149 counter + 1,
150 delay
151 );
152
153 #[cfg(feature = "tokio")]
154 let sleep_fut = tokio::time::sleep(delay);
155 #[cfg(feature = "async-std")]
156 let sleep_fut = async_std::task::sleep(delay);
157
158 sleep_fut.await;
159
160 debug!("Attempting reconnect #{} now.", counter + 1);
161 }
162 }
163 }
164 }
165
166 match result.unwrap() {
167 Ok(inner) => Ok(ReconnectStream {
168 status: Status::Connected,
169 ctor_arg,
170 underlying_io: inner,
171 options,
172 _marker: PhantomData,
173 }),
174 Err(e) => {
175 error!("No more re-connect retries remaining. Never able to establish initial connection.");
176 Err(e)
177 }
178 }
179 }
180
181 fn on_disconnect(mut self: Pin<&mut Self>, cx: &mut Context) {
182 match &mut self.status {
183 Status::Connected => {
185 error!("Disconnect occurred");
186 (self.options.on_disconnect_callback())();
187 self.status = Status::Disconnected(ReconnectStatus::new(&self.options));
188 }
189 Status::Disconnected(_) => {
190 (self.options.on_connect_fail_callback())();
191 }
192 Status::FailedAndExhausted => {
193 unreachable!("on_disconnect will not occur for already exhausted state.")
194 }
195 };
196
197 let ctor_arg = self.ctor_arg.clone();
198
199 if let Status::Disconnected(reconnect_status) = &mut self.status {
201 let next_duration = match reconnect_status.attempts_tracker.retries_remaining.next() {
202 Some(duration) => duration,
203 None => {
204 error!("No more re-connect retries remaining. Giving up.");
205 self.status = Status::FailedAndExhausted;
206 cx.waker().wake_by_ref();
207 return;
208 }
209 };
210
211 #[cfg(feature = "tokio")]
212 let future_instant = tokio::time::sleep(next_duration);
213 #[cfg(feature = "async-std")]
214 let future_instant = async_std::task::sleep(next_duration);
215
216 reconnect_status.attempts_tracker.attempt_num += 1;
217 let cur_num = reconnect_status.attempts_tracker.attempt_num;
218
219 let reconnect_attempt = async move {
220 future_instant.await;
221 debug!("Attempting reconnect #{} now.", cur_num);
222 T::establish(ctor_arg).await
223 };
224
225 reconnect_status.reconnect_attempt = Box::pin(reconnect_attempt);
226
227 debug!(
228 "Will perform reconnect attempt #{} in {:?}.",
229 reconnect_status.attempts_tracker.attempt_num, next_duration
230 );
231
232 cx.waker().wake_by_ref();
233 }
234 }
235
236 fn poll_disconnect(mut self: Pin<&mut Self>, cx: &mut Context) {
237 let (attempt, attempt_num) = match &mut self.status {
238 Status::Connected => unreachable!(),
239 Status::Disconnected(ref mut status) => (
240 Pin::new(&mut status.reconnect_attempt),
241 status.attempts_tracker.attempt_num,
242 ),
243 Status::FailedAndExhausted => unreachable!(),
244 };
245
246 match attempt.poll(cx) {
247 Poll::Ready(Ok(underlying_io)) => {
248 info!("Connection re-established");
249 cx.waker().wake_by_ref();
250 self.status = Status::Connected;
251 (self.options.on_connect_callback())();
252 self.underlying_io = underlying_io;
253 }
254 Poll::Ready(Err(err)) => {
255 error!("Connection attempt #{} failed: {:?}", attempt_num, err);
256 self.on_disconnect(cx);
257 }
258 Poll::Pending => {}
259 }
260 }
261
262 fn is_write_disconnect_detected<X>(&self, poll_result: &Poll<Result<X, E>>) -> bool {
263 match poll_result {
264 Poll::Ready(Err(err)) => self.is_write_disconnect_error(err),
265 _ => false,
266 }
267 }
268}
269
270impl<T, C, I, E> Stream for ReconnectStream<T, C, I, E>
271where
272 T: UnderlyingStream<C, I, E> + Stream<Item = I>,
273 C: Clone + Send + Unpin + 'static,
274 I: Unpin,
275 E: Error + Unpin,
276{
277 type Item = I;
278
279 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
280 match self.status {
281 Status::Connected => {
282 let poll = ready!(Pin::new(&mut self.underlying_io).poll_next(cx));
283 if let Some(poll) = poll {
284 if self.is_read_disconnect_error(&poll) {
285 self.on_disconnect(cx);
286 Poll::Pending
287 } else {
288 Poll::Ready(Some(poll))
289 }
290 } else {
291 self.on_disconnect(cx);
292 Poll::Pending
293 }
294 }
295 Status::Disconnected(_) => {
296 self.poll_disconnect(cx);
297 Poll::Pending
298 }
299 Status::FailedAndExhausted => Poll::Ready(None),
300 }
301 }
302
303 fn size_hint(&self) -> (usize, Option<usize>) {
304 self.underlying_io.size_hint()
305 }
306}
307
308impl<T, C, I, I2, E> Sink<I> for ReconnectStream<T, C, I2, E>
309where
310 T: UnderlyingStream<C, I2, E> + Sink<I, Error = E>,
311 C: Clone + Send + Unpin + 'static,
312 I2: Unpin,
313 E: Error + Unpin,
314{
315 type Error = E;
316
317 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
318 match self.status {
319 Status::Connected => {
320 let poll = Pin::new(&mut self.underlying_io).poll_ready(cx);
321
322 if self.is_write_disconnect_detected(&poll) {
323 self.on_disconnect(cx);
324 Poll::Pending
325 } else {
326 poll
327 }
328 }
329 Status::Disconnected(_) => {
330 self.poll_disconnect(cx);
331 Poll::Pending
332 }
333 Status::FailedAndExhausted => Poll::Ready(Err(T::exhaust_err())),
334 }
335 }
336
337 fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
338 Pin::new(&mut self.underlying_io).start_send(item)
339 }
340
341 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
342 match self.status {
343 Status::Connected => {
344 let poll = Pin::new(&mut self.underlying_io).poll_flush(cx);
345
346 if self.is_write_disconnect_detected(&poll) {
347 self.on_disconnect(cx);
348 Poll::Pending
349 } else {
350 poll
351 }
352 }
353 Status::Disconnected(_) => {
354 self.poll_disconnect(cx);
355 Poll::Pending
356 }
357 Status::FailedAndExhausted => Poll::Ready(Err(T::exhaust_err())),
358 }
359 }
360
361 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
362 match self.status {
363 Status::Connected => {
364 let poll = Pin::new(&mut self.underlying_io).poll_close(cx);
365 if poll.is_ready() {
366 self.on_disconnect(cx);
368 }
369
370 poll
371 }
372 Status::Disconnected(_) => Poll::Pending,
373 Status::FailedAndExhausted => Poll::Ready(Err(T::exhaust_err())),
374 }
375 }
376}