// spacegate_kernel/helper_layers/timeout.rs
use std::{
    convert::Infallible,
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

use futures_util::Future;
use hyper::{Request, Response};
use tokio::time::Sleep;
use tower_layer::Layer;

use crate::SgBody;

8#[derive(Clone, Debug)]
9pub struct TimeoutLayer {
10 pub timeout: Duration,
12 pub timeout_response: hyper::body::Bytes,
13}
14
15impl<S> Layer<S> for TimeoutLayer {
16 type Service = Timeout<S>;
17
18 fn layer(&self, inner: S) -> Self::Service {
19 Timeout {
20 inner,
21 timeout: self.timeout,
22 timeout_response: self.timeout_response.clone(),
23 }
24 }
25}
26
27#[derive(Clone, Debug)]
28pub struct Timeout<S> {
29 inner: S,
30 timeout: Duration,
31 timeout_response: hyper::body::Bytes,
32}
33
34impl TimeoutLayer {
35 pub fn new(timeout: Duration) -> Self {
36 Self {
37 timeout,
38 timeout_response: hyper::body::Bytes::default(),
39 }
40 }
41 pub fn set_timeout(&mut self, timeout: Duration) {
42 self.timeout = timeout;
43 }
44}
45
46impl<S> Timeout<S> {
47 pub fn new(timeout: Duration, timeout_response: hyper::body::Bytes, inner: S) -> Self {
48 Self { inner, timeout, timeout_response }
49 }
50}
51
52impl<S> hyper::service::Service<Request<SgBody>> for Timeout<S>
53where
54 S: hyper::service::Service<Request<SgBody>, Response = Response<SgBody>, Error = Infallible> + Send + 'static,
55 <S as hyper::service::Service<Request<SgBody>>>::Future: std::marker::Send,
56{
57 type Response = Response<SgBody>;
58
59 type Error = Infallible;
60
61 type Future = TimeoutFuture<S::Future>;
62
63 fn call(&self, req: Request<SgBody>) -> Self::Future {
64 TimeoutFuture {
65 inner: self.inner.call(req),
66 timeout: tokio::time::sleep(self.timeout),
67 timeout_response: self.timeout_response.clone(),
68 }
69 }
70}
71
72pin_project_lite::pin_project! {
73 pub struct TimeoutFuture<F> {
74 #[pin]
75 inner: F,
76 #[pin]
77 timeout: Sleep,
78 timeout_response: hyper::body::Bytes,
79 }
80}
81
82impl<F> Future for TimeoutFuture<F>
83where
84 F: Future<Output = Result<Response<SgBody>, Infallible>> + Send + 'static,
85{
86 type Output = Result<Response<SgBody>, Infallible>;
87
88 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
89 let this = self.project();
90 if this.timeout.poll(cx).is_ready() {
91 let response = Response::builder().status(hyper::StatusCode::GATEWAY_TIMEOUT).body(SgBody::full(this.timeout_response.clone())).expect("invalid response");
92 return std::task::Poll::Ready(Ok(response));
93 }
94 this.inner.poll(cx)
95 }
96}