spacegate_kernel/helper_layers/timeout.rs

1use std::{convert::Infallible, time::Duration};
2
3use crate::SgBody;
4use futures_util::Future;
5use hyper::{Request, Response};
6use tokio::time::Sleep;
7use tower_layer::Layer;
8#[derive(Clone, Debug)]
9pub struct TimeoutLayer {
10    /// timeout duration
11    pub timeout: Duration,
12    pub timeout_response: hyper::body::Bytes,
13}
14
15impl<S> Layer<S> for TimeoutLayer {
16    type Service = Timeout<S>;
17
18    fn layer(&self, inner: S) -> Self::Service {
19        Timeout {
20            inner,
21            timeout: self.timeout,
22            timeout_response: self.timeout_response.clone(),
23        }
24    }
25}
26
27#[derive(Clone, Debug)]
28pub struct Timeout<S> {
29    inner: S,
30    timeout: Duration,
31    timeout_response: hyper::body::Bytes,
32}
33
34impl TimeoutLayer {
35    pub fn new(timeout: Duration) -> Self {
36        Self {
37            timeout,
38            timeout_response: hyper::body::Bytes::default(),
39        }
40    }
41    pub fn set_timeout(&mut self, timeout: Duration) {
42        self.timeout = timeout;
43    }
44}
45
46impl<S> Timeout<S> {
47    pub fn new(timeout: Duration, timeout_response: hyper::body::Bytes, inner: S) -> Self {
48        Self { inner, timeout, timeout_response }
49    }
50}
51
52impl<S> hyper::service::Service<Request<SgBody>> for Timeout<S>
53where
54    S: hyper::service::Service<Request<SgBody>, Response = Response<SgBody>, Error = Infallible> + Send + 'static,
55    <S as hyper::service::Service<Request<SgBody>>>::Future: std::marker::Send,
56{
57    type Response = Response<SgBody>;
58
59    type Error = Infallible;
60
61    type Future = TimeoutFuture<S::Future>;
62
63    fn call(&self, req: Request<SgBody>) -> Self::Future {
64        TimeoutFuture {
65            inner: self.inner.call(req),
66            timeout: tokio::time::sleep(self.timeout),
67            timeout_response: self.timeout_response.clone(),
68        }
69    }
70}
71
72pin_project_lite::pin_project! {
73    pub struct TimeoutFuture<F> {
74        #[pin]
75        inner: F,
76        #[pin]
77        timeout: Sleep,
78        timeout_response: hyper::body::Bytes,
79    }
80}
81
82impl<F> Future for TimeoutFuture<F>
83where
84    F: Future<Output = Result<Response<SgBody>, Infallible>> + Send + 'static,
85{
86    type Output = Result<Response<SgBody>, Infallible>;
87
88    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
89        let this = self.project();
90        if this.timeout.poll(cx).is_ready() {
91            let response = Response::builder().status(hyper::StatusCode::GATEWAY_TIMEOUT).body(SgBody::full(this.timeout_response.clone())).expect("invalid response");
92            return std::task::Poll::Ready(Ok(response));
93        }
94        this.inner.poll(cx)
95    }
96}