volans_core/transport/
timeout.rs1use std::{
2 error, fmt,
3 marker::PhantomData,
4 pin::Pin,
5 task::{Context, Poll},
6 time::Duration,
7};
8
9use futures::TryFuture;
10use futures_timer::Delay;
11
12use crate::{Listener, Multiaddr, Transport, TransportError};
13
/// Transport wrapper that puts a deadline on outgoing dials and on incoming upgrades.
///
/// Every future returned by [`Transport::dial`] is raced against the outgoing timeout, and
/// every upgrade yielded by the listener returned from [`Transport::listen`] is raced against
/// the incoming timeout. If the deadline is reached while the wrapped operation is still
/// pending, the operation resolves to [`TimeoutError::Timeout`].
#[derive(Debug, Clone)]
pub struct Timeout<T> {
    /// The wrapped transport that performs the actual dialing and listening.
    inner: T,
    /// Deadline armed for each outgoing dial.
    outgoing_timeout: Duration,
    /// Deadline armed for each incoming upgrade.
    incoming_timeout: Duration,
}

impl<T> Timeout<T> {
    /// Wraps `inner`, using the same `timeout` for both outgoing dials and incoming upgrades.
    ///
    /// Use [`Timeout::with_outgoing_timeout`] / [`Timeout::with_incoming_timeout`] to give the
    /// two directions different deadlines.
    #[must_use]
    pub fn new(inner: T, timeout: Duration) -> Self {
        Self {
            inner,
            outgoing_timeout: timeout,
            incoming_timeout: timeout,
        }
    }

    /// Returns the deadline currently applied to outgoing dials.
    pub fn outgoing_timeout(&self) -> Duration {
        self.outgoing_timeout
    }

    /// Returns the deadline currently applied to incoming upgrades.
    pub fn incoming_timeout(&self) -> Duration {
        self.incoming_timeout
    }

    /// Replaces the outgoing-dial deadline, leaving the incoming deadline unchanged.
    ///
    /// Consumes `self` and returns the reconfigured wrapper.
    #[must_use = "this consumes the transport and returns the reconfigured one"]
    pub fn with_outgoing_timeout(mut self, timeout: Duration) -> Self {
        self.outgoing_timeout = timeout;
        self
    }

    /// Replaces the incoming-upgrade deadline, leaving the outgoing deadline unchanged.
    ///
    /// Consumes `self` and returns the reconfigured wrapper.
    #[must_use = "this consumes the transport and returns the reconfigured one"]
    pub fn with_incoming_timeout(mut self, timeout: Duration) -> Self {
        self.incoming_timeout = timeout;
        self
    }
}
48
49impl<T> Transport for Timeout<T>
50where
51 T: Transport,
52 T::Error: 'static,
53{
54 type Output = T::Output;
55 type Error = TimeoutError<T::Error>;
56 type Dial = TimeoutFuture<T::Dial>;
57 type Incoming = TimeoutFuture<T::Incoming>;
58 type Listener = TimeoutListener<T>;
59
60 fn dial(&self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
61 let fut = self
62 .inner
63 .dial(addr)
64 .map_err(|e| e.map(TimeoutError::Other))?;
65 Ok(TimeoutFuture {
66 inner: fut,
67 timer: Delay::new(self.outgoing_timeout),
68 })
69 }
70
71 fn listen(&self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
72 let listener = self
73 .inner
74 .listen(addr)
75 .map_err(|e| e.map(TimeoutError::Other))?;
76 Ok(TimeoutListener {
77 inner: listener,
78 timeout: self.incoming_timeout,
79 _marker: PhantomData,
80 })
81 }
82}
83
84#[pin_project::pin_project]
85pub struct TimeoutListener<T>
86where
87 T: Transport,
88{
89 #[pin]
90 inner: T::Listener,
91 timeout: Duration,
92 _marker: PhantomData<T>,
93}
94
95impl<T> Listener for TimeoutListener<T>
96where
97 T: Transport,
98 T::Error: 'static,
99{
100 type Output = T::Output;
101 type Error = TimeoutError<T::Error>;
102 type Upgrade = TimeoutFuture<T::Incoming>;
103
104 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105 let this = self.project();
106 this.inner.poll_close(cx).map_err(TimeoutError::Other)
107 }
108
109 fn poll_event(
110 self: Pin<&mut Self>,
111 cx: &mut Context<'_>,
112 ) -> Poll<super::ListenerEvent<Self::Upgrade, Self::Error>> {
113 let this = self.project();
114 let timeout = *this.timeout;
115 this.inner.poll_event(cx).map(|event| {
116 event
117 .map_upgrade(move |u| TimeoutFuture {
118 inner: u,
119 timer: Delay::new(timeout),
120 })
121 .map_err(TimeoutError::Other)
122 })
123 }
124}
125
126#[pin_project::pin_project]
127pub struct TimeoutFuture<TFut> {
128 #[pin]
129 inner: TFut,
130 timer: Delay,
131}
132
133impl<TFut> Future for TimeoutFuture<TFut>
134where
135 TFut: TryFuture,
136{
137 type Output = Result<TFut::Ok, TimeoutError<TFut::Error>>;
138
139 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
140 let mut this = self.project();
141
142 match TryFuture::try_poll(this.inner, cx) {
143 Poll::Pending => {}
144 Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
145 Poll::Ready(Err(err)) => return Poll::Ready(Err(TimeoutError::Other(err))),
146 }
147 match Pin::new(&mut this.timer).poll(cx) {
149 Poll::Pending => Poll::Pending,
150 Poll::Ready(()) => Poll::Ready(Err(TimeoutError::Timeout)),
151 }
152 }
153}
154
/// Error produced by the [`Timeout`] transport wrapper and the futures it creates.
#[derive(Debug)]
pub enum TimeoutError<TErr> {
    /// The deadline was reached before the wrapped operation completed.
    Timeout,
    /// The wrapped transport, listener or future failed with its own error.
    Other(TErr),
}

impl<TErr> fmt::Display for TimeoutError<TErr>
where
    TErr: fmt::Display,
{
    /// Writes a fixed message for [`TimeoutError::Timeout`], or prefixes the inner error's
    /// message for [`TimeoutError::Other`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Timeout => f.write_str("Operation timed out"),
            Self::Other(inner) => write!(f, "Other error: {}", inner),
        }
    }
}

impl<TErr> error::Error for TimeoutError<TErr>
where
    TErr: error::Error + 'static,
{
    /// Exposes the wrapped error for [`TimeoutError::Other`]; a timeout has no underlying cause.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            Self::Other(inner) => Some(inner),
            Self::Timeout => None,
        }
    }
}