torrust_index/web/api/server/
custom_axum.rs1use std::future::Ready;
20use std::io::ErrorKind;
21use std::net::TcpListener;
22use std::pin::Pin;
23use std::task::{Context, Poll};
24use std::time::Duration;
25
26use axum_server::accept::Accept;
27use axum_server::tls_rustls::{RustlsAcceptor, RustlsConfig};
28use axum_server::Server;
29use futures_util::{ready, Future};
30use http_body::{Body, Frame};
31use hyper::Response;
32use hyper_util::rt::TokioTimer;
33use pin_project_lite::pin_project;
34use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
35use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
36use tokio::time::{Instant, Sleep};
37use tower::Service;
38
/// Maximum time allowed for reading the headers of an HTTP/1 request.
const HTTP1_HEADER_READ_TIMEOUT: Duration = Duration::from_secs(5);

/// Value handed to the HTTP/2 builder's `keep_alive_timeout` setting.
const HTTP2_KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(5);

/// Value handed to the HTTP/2 builder's `keep_alive_interval` setting.
const HTTP2_KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(5);
42
43#[must_use]
44pub fn from_tcp_with_timeouts(socket: TcpListener) -> Server {
45 add_timeouts(axum_server::from_tcp(socket))
46}
47
48#[must_use]
49pub fn from_tcp_rustls_with_timeouts(socket: TcpListener, tls: RustlsConfig) -> Server<RustlsAcceptor> {
50 add_timeouts(axum_server::from_tcp_rustls(socket, tls))
51}
52
53fn add_timeouts<A>(mut server: Server<A>) -> Server<A> {
54 server.http_builder().http1().timer(TokioTimer::new());
55 server.http_builder().http2().timer(TokioTimer::new());
56
57 server.http_builder().http1().header_read_timeout(HTTP1_HEADER_READ_TIMEOUT);
58 server
59 .http_builder()
60 .http2()
61 .keep_alive_timeout(HTTP2_KEEP_ALIVE_TIMEOUT)
62 .keep_alive_interval(HTTP2_KEEP_ALIVE_INTERVAL);
63
64 server
65}
66
67#[derive(Clone)]
68pub struct TimeoutAcceptor;
69
70impl<I, S> Accept<I, S> for TimeoutAcceptor {
71 type Stream = TimeoutStream<I>;
72 type Service = TimeoutService<S>;
73 type Future = Ready<std::io::Result<(Self::Stream, Self::Service)>>;
74
75 fn accept(&self, stream: I, service: S) -> Self::Future {
76 let (tx, rx) = mpsc::unbounded_channel();
77
78 let stream = TimeoutStream::new(stream, HTTP1_HEADER_READ_TIMEOUT, rx);
79 let service = TimeoutService::new(service, tx);
80
81 std::future::ready(Ok((stream, service)))
82 }
83}
84
85#[derive(Clone)]
86pub struct TimeoutService<S> {
87 inner: S,
88 sender: UnboundedSender<TimerSignal>,
89}
90
91impl<S> TimeoutService<S> {
92 fn new(inner: S, sender: UnboundedSender<TimerSignal>) -> Self {
93 Self { inner, sender }
94 }
95}
96
97impl<S, B, Request> Service<Request> for TimeoutService<S>
98where
99 S: Service<Request, Response = Response<B>>,
100{
101 type Response = Response<TimeoutBody<B>>;
102 type Error = S::Error;
103 type Future = TimeoutServiceFuture<S::Future>;
104
105 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106 self.inner.poll_ready(cx)
107 }
108
109 fn call(&mut self, req: Request) -> Self::Future {
110 let _ = self.sender.send(TimerSignal::Wait);
112
113 TimeoutServiceFuture::new(self.inner.call(req), self.sender.clone())
114 }
115}
116
117pin_project! {
118 pub struct TimeoutServiceFuture<F> {
119 #[pin]
120 inner: F,
121 sender: Option<UnboundedSender<TimerSignal>>,
122 }
123}
124
125impl<F> TimeoutServiceFuture<F> {
126 fn new(inner: F, sender: UnboundedSender<TimerSignal>) -> Self {
127 Self {
128 inner,
129 sender: Some(sender),
130 }
131 }
132}
133
134impl<F, B, E> Future for TimeoutServiceFuture<F>
135where
136 F: Future<Output = Result<Response<B>, E>>,
137{
138 type Output = Result<Response<TimeoutBody<B>>, E>;
139
140 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141 let this = self.project();
142 this.inner.poll(cx).map(|result| {
143 result.map(|response| {
144 response.map(|body| TimeoutBody::new(body, this.sender.take().expect("future polled after ready")))
145 })
146 })
147 }
148}
149
/// Messages sent from the service/body side to the [`TimeoutStream`].
enum TimerSignal {
    /// A request has been handed to the service: the stream stops checking
    /// its timer.
    Wait,
    /// The response body reached its end: the stream re-arms its timer and
    /// resumes checking it.
    Reset,
}
154
155pin_project! {
156 pub struct TimeoutBody<B> {
157 #[pin]
158 inner: B,
159 sender: UnboundedSender<TimerSignal>,
160 }
161}
162
163impl<B> TimeoutBody<B> {
164 fn new(inner: B, sender: UnboundedSender<TimerSignal>) -> Self {
165 Self { inner, sender }
166 }
167}
168
169impl<B: Body> Body for TimeoutBody<B> {
170 type Data = B::Data;
171 type Error = B::Error;
172
173 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
174 let this = self.project();
175 let option = ready!(this.inner.poll_frame(cx));
176
177 if option.is_none() {
178 let _ = this.sender.send(TimerSignal::Reset);
179 }
180
181 Poll::Ready(option)
182 }
183
184 fn is_end_stream(&self) -> bool {
185 let is_end_stream = self.inner.is_end_stream();
186
187 if is_end_stream {
188 let _ = self.sender.send(TimerSignal::Reset);
189 }
190
191 is_end_stream
192 }
193
194 fn size_hint(&self) -> http_body::SizeHint {
195 self.inner.size_hint()
196 }
197}
198
199pub struct TimeoutStream<IO> {
200 inner: IO,
201 sleep: Pin<Box<Sleep>>,
203 duration: Duration,
204 waiting: bool,
205 receiver: UnboundedReceiver<TimerSignal>,
206 finished: bool,
207}
208
209impl<IO> TimeoutStream<IO> {
210 fn new(inner: IO, duration: Duration, receiver: UnboundedReceiver<TimerSignal>) -> Self {
211 Self {
212 inner,
213 sleep: Box::pin(tokio::time::sleep(duration)),
214 duration,
215 waiting: false,
216 receiver,
217 finished: false,
218 }
219 }
220}
221
222impl<IO: AsyncRead + Unpin> AsyncRead for TimeoutStream<IO> {
223 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
224 if !self.finished {
225 match Pin::new(&mut self.receiver).poll_recv(cx) {
226 Poll::Ready(Some(TimerSignal::Reset)) => {
228 self.waiting = false;
229
230 let deadline = Instant::now() + self.duration;
231 self.sleep.as_mut().reset(deadline);
232 }
233 Poll::Ready(Some(TimerSignal::Wait)) => self.waiting = true,
235 Poll::Ready(None) => self.finished = true,
236 Poll::Pending => (),
237 }
238 }
239
240 if !self.waiting {
241 if let Poll::Ready(()) = self.sleep.as_mut().poll(cx) {
243 return Poll::Ready(Err(std::io::Error::new(ErrorKind::TimedOut, "request header read timed out")));
244 }
245 }
246
247 Pin::new(&mut self.inner).poll_read(cx, buf)
248 }
249}
250
251impl<IO: AsyncWrite + Unpin> AsyncWrite for TimeoutStream<IO> {
252 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
253 Pin::new(&mut self.inner).poll_write(cx, buf)
254 }
255
256 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
257 Pin::new(&mut self.inner).poll_flush(cx)
258 }
259
260 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
261 Pin::new(&mut self.inner).poll_shutdown(cx)
262 }
263
264 fn poll_write_vectored(
265 mut self: Pin<&mut Self>,
266 cx: &mut Context<'_>,
267 bufs: &[std::io::IoSlice<'_>],
268 ) -> Poll<Result<usize, std::io::Error>> {
269 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
270 }
271
272 fn is_write_vectored(&self) -> bool {
273 self.inner.is_write_vectored()
274 }
275}