torrust_index/web/api/server/
custom_axum.rs

1//! Wrapper for Axum server to add timeouts.
2//!
3//! Copyright (c) Eray Karatay ([@programatik29](https://github.com/programatik29)).
4//!
5//! See: <https://gist.github.com/programatik29/36d371c657392fd7f322e7342957b6d1>.
6//!
//! If a client opens an HTTP connection but does not send any request, the
//! connection is closed after a timeout. You can test it with:
9//!
10//! ```text
11//! telnet 127.0.0.1 1212
12//! Trying 127.0.0.1...
13//! Connected to 127.0.0.1.
14//! Escape character is '^]'.
15//! Connection closed by foreign host.
16//! ```
17//!
//! For more about Axum and timeouts, see <https://github.com/josecelano/axum-server-timeout>.
19use std::future::Ready;
20use std::io::ErrorKind;
21use std::net::TcpListener;
22use std::pin::Pin;
23use std::task::{Context, Poll};
24use std::time::Duration;
25
26use axum_server::accept::Accept;
27use axum_server::tls_rustls::{RustlsAcceptor, RustlsConfig};
28use axum_server::Server;
29use futures_util::{ready, Future};
30use http_body::{Body, Frame};
31use hyper::Response;
32use hyper_util::rt::TokioTimer;
33use pin_project_lite::pin_project;
34use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
35use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
36use tokio::time::{Instant, Sleep};
37use tower::Service;
38
/// Time allowed for a client to send complete HTTP/1 request headers.
///
/// Also used as the idle window of `TimeoutStream` when `TimeoutAcceptor` is used.
const HTTP1_HEADER_READ_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How long to wait for an HTTP/2 keep-alive PING acknowledgement.
const HTTP2_KEEP_ALIVE_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How often an HTTP/2 keep-alive PING is sent.
const HTTP2_KEEP_ALIVE_INTERVAL: Duration = Duration::from_millis(5_000);
42
43#[must_use]
44pub fn from_tcp_with_timeouts(socket: TcpListener) -> Server {
45    add_timeouts(axum_server::from_tcp(socket))
46}
47
48#[must_use]
49pub fn from_tcp_rustls_with_timeouts(socket: TcpListener, tls: RustlsConfig) -> Server<RustlsAcceptor> {
50    add_timeouts(axum_server::from_tcp_rustls(socket, tls))
51}
52
53fn add_timeouts<A>(mut server: Server<A>) -> Server<A> {
54    server.http_builder().http1().timer(TokioTimer::new());
55    server.http_builder().http2().timer(TokioTimer::new());
56
57    server.http_builder().http1().header_read_timeout(HTTP1_HEADER_READ_TIMEOUT);
58    server
59        .http_builder()
60        .http2()
61        .keep_alive_timeout(HTTP2_KEEP_ALIVE_TIMEOUT)
62        .keep_alive_interval(HTTP2_KEEP_ALIVE_INTERVAL);
63
64    server
65}
66
/// `axum_server` acceptor that wraps each accepted connection in a
/// `TimeoutStream` and its service in a `TimeoutService`.
///
/// The two halves share a channel, which lets the read timeout be suspended
/// while a request is being served.
#[derive(Clone, Debug, Default)]
pub struct TimeoutAcceptor;
69
70impl<I, S> Accept<I, S> for TimeoutAcceptor {
71    type Stream = TimeoutStream<I>;
72    type Service = TimeoutService<S>;
73    type Future = Ready<std::io::Result<(Self::Stream, Self::Service)>>;
74
75    fn accept(&self, stream: I, service: S) -> Self::Future {
76        let (tx, rx) = mpsc::unbounded_channel();
77
78        let stream = TimeoutStream::new(stream, HTTP1_HEADER_READ_TIMEOUT, rx);
79        let service = TimeoutService::new(service, tx);
80
81        std::future::ready(Ok((stream, service)))
82    }
83}
84
85#[derive(Clone)]
86pub struct TimeoutService<S> {
87    inner: S,
88    sender: UnboundedSender<TimerSignal>,
89}
90
91impl<S> TimeoutService<S> {
92    fn new(inner: S, sender: UnboundedSender<TimerSignal>) -> Self {
93        Self { inner, sender }
94    }
95}
96
97impl<S, B, Request> Service<Request> for TimeoutService<S>
98where
99    S: Service<Request, Response = Response<B>>,
100{
101    type Response = Response<TimeoutBody<B>>;
102    type Error = S::Error;
103    type Future = TimeoutServiceFuture<S::Future>;
104
105    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106        self.inner.poll_ready(cx)
107    }
108
109    fn call(&mut self, req: Request) -> Self::Future {
110        // send timer wait signal
111        let _ = self.sender.send(TimerSignal::Wait);
112
113        TimeoutServiceFuture::new(self.inner.call(req), self.sender.clone())
114    }
115}
116
117pin_project! {
118    pub struct TimeoutServiceFuture<F> {
119        #[pin]
120        inner: F,
121        sender: Option<UnboundedSender<TimerSignal>>,
122    }
123}
124
125impl<F> TimeoutServiceFuture<F> {
126    fn new(inner: F, sender: UnboundedSender<TimerSignal>) -> Self {
127        Self {
128            inner,
129            sender: Some(sender),
130        }
131    }
132}
133
134impl<F, B, E> Future for TimeoutServiceFuture<F>
135where
136    F: Future<Output = Result<Response<B>, E>>,
137{
138    type Output = Result<Response<TimeoutBody<B>>, E>;
139
140    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141        let this = self.project();
142        this.inner.poll(cx).map(|result| {
143            result.map(|response| {
144                response.map(|body| TimeoutBody::new(body, this.sender.take().expect("future polled after ready")))
145            })
146        })
147    }
148}
149
/// Message sent from the service/body side of a connection to its `TimeoutStream`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimerSignal {
    /// A request is being handled: suspend the read timeout.
    Wait,
    /// A response has finished: leave waiting mode and restart the timeout.
    Reset,
}
154
155pin_project! {
156    pub struct TimeoutBody<B> {
157        #[pin]
158        inner: B,
159        sender: UnboundedSender<TimerSignal>,
160    }
161}
162
163impl<B> TimeoutBody<B> {
164    fn new(inner: B, sender: UnboundedSender<TimerSignal>) -> Self {
165        Self { inner, sender }
166    }
167}
168
169impl<B: Body> Body for TimeoutBody<B> {
170    type Data = B::Data;
171    type Error = B::Error;
172
173    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
174        let this = self.project();
175        let option = ready!(this.inner.poll_frame(cx));
176
177        if option.is_none() {
178            let _ = this.sender.send(TimerSignal::Reset);
179        }
180
181        Poll::Ready(option)
182    }
183
184    fn is_end_stream(&self) -> bool {
185        let is_end_stream = self.inner.is_end_stream();
186
187        if is_end_stream {
188            let _ = self.sender.send(TimerSignal::Reset);
189        }
190
191        is_end_stream
192    }
193
194    fn size_hint(&self) -> http_body::SizeHint {
195        self.inner.size_hint()
196    }
197}
198
199pub struct TimeoutStream<IO> {
200    inner: IO,
201    // hyper requires unpin
202    sleep: Pin<Box<Sleep>>,
203    duration: Duration,
204    waiting: bool,
205    receiver: UnboundedReceiver<TimerSignal>,
206    finished: bool,
207}
208
209impl<IO> TimeoutStream<IO> {
210    fn new(inner: IO, duration: Duration, receiver: UnboundedReceiver<TimerSignal>) -> Self {
211        Self {
212            inner,
213            sleep: Box::pin(tokio::time::sleep(duration)),
214            duration,
215            waiting: false,
216            receiver,
217            finished: false,
218        }
219    }
220}
221
222impl<IO: AsyncRead + Unpin> AsyncRead for TimeoutStream<IO> {
223    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
224        if !self.finished {
225            match Pin::new(&mut self.receiver).poll_recv(cx) {
226                // reset the timer
227                Poll::Ready(Some(TimerSignal::Reset)) => {
228                    self.waiting = false;
229
230                    let deadline = Instant::now() + self.duration;
231                    self.sleep.as_mut().reset(deadline);
232                }
233                // enter waiting mode (for response body last chunk)
234                Poll::Ready(Some(TimerSignal::Wait)) => self.waiting = true,
235                Poll::Ready(None) => self.finished = true,
236                Poll::Pending => (),
237            }
238        }
239
240        if !self.waiting {
241            // return error if timer is elapsed
242            if let Poll::Ready(()) = self.sleep.as_mut().poll(cx) {
243                return Poll::Ready(Err(std::io::Error::new(ErrorKind::TimedOut, "request header read timed out")));
244            }
245        }
246
247        Pin::new(&mut self.inner).poll_read(cx, buf)
248    }
249}
250
251impl<IO: AsyncWrite + Unpin> AsyncWrite for TimeoutStream<IO> {
252    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
253        Pin::new(&mut self.inner).poll_write(cx, buf)
254    }
255
256    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
257        Pin::new(&mut self.inner).poll_flush(cx)
258    }
259
260    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
261        Pin::new(&mut self.inner).poll_shutdown(cx)
262    }
263
264    fn poll_write_vectored(
265        mut self: Pin<&mut Self>,
266        cx: &mut Context<'_>,
267        bufs: &[std::io::IoSlice<'_>],
268    ) -> Poll<Result<usize, std::io::Error>> {
269        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
270    }
271
272    fn is_write_vectored(&self) -> bool {
273        self.inner.is_write_vectored()
274    }
275}