// volans_core/transport/timeout.rs

use std::{
    error, fmt,
    future::Future,
    marker::PhantomData,
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

use futures::TryFuture;
use futures_timer::Delay;

use crate::{Listener, Multiaddr, Transport, TransportError};

/// Transport wrapper that puts a time limit on dials and on inbound upgrades.
///
/// Dial futures produced by this wrapper are bounded by `outgoing_timeout`. Upgrade
/// futures emitted by listeners created through this wrapper are bounded by
/// `incoming_timeout`. An operation that exceeds its limit resolves to
/// `TimeoutError::Timeout`.
#[derive(Debug, Clone)]
pub struct Timeout<T> {
    /// The transport being wrapped.
    inner: T,
    /// Time limit for each outgoing dial.
    outgoing_timeout: Duration,
    /// Time limit for each inbound upgrade.
    incoming_timeout: Duration,
}

impl<T> Timeout<T> {
    /// Wraps `inner`, using `timeout` for both outgoing dials and incoming upgrades.
    ///
    /// Use [`with_outgoing_timeout`](Self::with_outgoing_timeout) or
    /// [`with_incoming_timeout`](Self::with_incoming_timeout) to set them separately.
    #[must_use]
    pub fn new(inner: T, timeout: Duration) -> Self {
        Self {
            inner,
            outgoing_timeout: timeout,
            incoming_timeout: timeout,
        }
    }

    /// Returns the time limit applied to outgoing dials.
    pub fn outgoing_timeout(&self) -> Duration {
        self.outgoing_timeout
    }

    /// Returns the time limit applied to inbound upgrades.
    pub fn incoming_timeout(&self) -> Duration {
        self.incoming_timeout
    }

    /// Returns `self` with the outgoing (dial) time limit replaced by `timeout`.
    ///
    /// `self` is consumed, so discarding the return value loses the wrapper entirely.
    #[must_use = "this consumes `self` and returns the reconfigured transport"]
    pub fn with_outgoing_timeout(mut self, timeout: Duration) -> Self {
        self.outgoing_timeout = timeout;
        self
    }

    /// Returns `self` with the inbound-upgrade time limit replaced by `timeout`.
    ///
    /// `self` is consumed, so discarding the return value loses the wrapper entirely.
    #[must_use = "this consumes `self` and returns the reconfigured transport"]
    pub fn with_incoming_timeout(mut self, timeout: Duration) -> Self {
        self.incoming_timeout = timeout;
        self
    }
}

49impl<T> Transport for Timeout<T>
50where
51    T: Transport,
52    T::Error: 'static,
53{
54    type Output = T::Output;
55    type Error = TimeoutError<T::Error>;
56    type Dial = TimeoutFuture<T::Dial>;
57    type Incoming = TimeoutFuture<T::Incoming>;
58    type Listener = TimeoutListener<T>;
59
60    fn dial(&self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
61        let fut = self
62            .inner
63            .dial(addr)
64            .map_err(|e| e.map(TimeoutError::Other))?;
65        Ok(TimeoutFuture {
66            inner: fut,
67            timer: Delay::new(self.outgoing_timeout),
68        })
69    }
70
71    fn listen(&self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
72        let listener = self
73            .inner
74            .listen(addr)
75            .map_err(|e| e.map(TimeoutError::Other))?;
76        Ok(TimeoutListener {
77            inner: listener,
78            timeout: self.incoming_timeout,
79            _marker: PhantomData,
80        })
81    }
82}
83
84#[pin_project::pin_project]
85pub struct TimeoutListener<T>
86where
87    T: Transport,
88{
89    #[pin]
90    inner: T::Listener,
91    timeout: Duration,
92    _marker: PhantomData<T>,
93}
94
95impl<T> Listener for TimeoutListener<T>
96where
97    T: Transport,
98    T::Error: 'static,
99{
100    type Output = T::Output;
101    type Error = TimeoutError<T::Error>;
102    type Upgrade = TimeoutFuture<T::Incoming>;
103
104    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105        let this = self.project();
106        this.inner.poll_close(cx).map_err(TimeoutError::Other)
107    }
108
109    fn poll_event(
110        self: Pin<&mut Self>,
111        cx: &mut Context<'_>,
112    ) -> Poll<super::ListenerEvent<Self::Upgrade, Self::Error>> {
113        let this = self.project();
114        let timeout = *this.timeout;
115        this.inner.poll_event(cx).map(|event| {
116            event
117                .map_upgrade(move |u| TimeoutFuture {
118                    inner: u,
119                    timer: Delay::new(timeout),
120                })
121                .map_err(TimeoutError::Other)
122        })
123    }
124}
125
126#[pin_project::pin_project]
127pub struct TimeoutFuture<TFut> {
128    #[pin]
129    inner: TFut,
130    timer: Delay,
131}
132
133impl<TFut> Future for TimeoutFuture<TFut>
134where
135    TFut: TryFuture,
136{
137    type Output = Result<TFut::Ok, TimeoutError<TFut::Error>>;
138
139    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
140        let mut this = self.project();
141
142        match TryFuture::try_poll(this.inner, cx) {
143            Poll::Pending => {}
144            Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
145            Poll::Ready(Err(err)) => return Poll::Ready(Err(TimeoutError::Other(err))),
146        }
147        // 检查是否超时
148        match Pin::new(&mut this.timer).poll(cx) {
149            Poll::Pending => Poll::Pending,
150            Poll::Ready(()) => Poll::Ready(Err(TimeoutError::Timeout)),
151        }
152    }
153}
154
/// Error type of the timeout transport wrapper in this module.
///
/// Besides `Debug`, it is `Clone`/`PartialEq`/`Eq` whenever the wrapped error is, so
/// callers can compare or duplicate results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TimeoutError<TErr> {
    /// The operation was still pending when its time limit elapsed.
    Timeout,
    /// The wrapped transport (or one of its futures/listeners) reported this error.
    Other(TErr),
}

impl<TErr> fmt::Display for TimeoutError<TErr>
where
    TErr: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TimeoutError::Timeout => write!(f, "Operation timed out"),
            TimeoutError::Other(err) => write!(f, "Other error: {}", err),
        }
    }
}

impl<TErr> error::Error for TimeoutError<TErr>
where
    TErr: error::Error + 'static,
{
    /// Exposes the wrapped error for `Other`; a timeout has no underlying cause.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            TimeoutError::Timeout => None,
            TimeoutError::Other(err) => Some(err),
        }
    }
}