// spacegate_kernel/backend_service.rs

use std::convert::Infallible;
use std::sync::Arc;

use futures_util::future::BoxFuture;
use futures_util::Future;
use hyper::{header::UPGRADE, Request, Response, StatusCode};
use tracing::instrument;

use crate::backend_service::http_client_service::get_client;
use crate::helper_layers::map_future::MapFuture;
use crate::utils::x_forwarded_for;
use crate::BoxError;
use crate::SgBody;
use crate::SgResponse;
use crate::SgResponseExt;

pub mod echo;
pub mod http_client_service;
pub mod static_file_service;
pub mod ws_client_service;
/// A hyper service that is `Send + Sync + 'static`, so it can be shared between threads.
///
/// Every type meeting those bounds implements this trait automatically through the blanket
/// impl that follows, which makes it behave like a trait alias.
pub trait SharedHyperService<R>
where
    Self: hyper::service::Service<R> + Send + Sync + 'static,
{
}

impl<R, T: hyper::service::Service<R> + Send + Sync + 'static> SharedHyperService<R> for T {}
/// A type-erased, reference-counted hyper service that can be shared between threads.
///
/// The inner service always yields a `BoxFuture`, so every `ArcHyperService` exposes the
/// same concrete future type no matter which service it was built from.
pub struct ArcHyperService {
    pub shared: Arc<
        dyn SharedHyperService<
                Request<SgBody>,
                Response = Response<SgBody>,
                Error = Infallible,
                Future = BoxFuture<'static, Result<Response<SgBody>, Infallible>>,
            > + Send
            + Sync,
    >,
}

impl std::fmt::Debug for ArcHyperService {
    /// Prints only the type name. The wrapped service is a trait object with no `Debug`
    /// bound, so there is nothing meaningful to show for it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ArcHyperService")
    }
}

impl Clone for ArcHyperService {
    fn clone(&self) -> Self {
        Self { shared: self.shared.clone() }
    }
}

impl ArcHyperService {
    /// Type-erases `service` into an [`ArcHyperService`].
    ///
    /// The service's future is boxed through `MapFuture`, so every wrapped service exposes
    /// the same `BoxFuture` future type. The result is then stored behind an `Arc`.
    pub fn new<T>(service: T) -> Self
    where
        T: SharedHyperService<Request<SgBody>, Response = Response<SgBody>, Error = Infallible> + Send + Sync + 'static,
        T::Future: Future<Output = Result<Response<SgBody>, Infallible>> + 'static + Send,
    {
        // The `as _` lets the target `BoxFuture` type be inferred from the `shared` field.
        let boxed_service = MapFuture::new(service, |inner_future| Box::pin(inner_future) as _);
        Self {
            shared: Arc::new(boxed_service),
        }
    }
}

impl hyper::service::Service<Request<SgBody>> for ArcHyperService {
    type Response = Response<SgBody>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Forwards `req` to the shared inner service.
    ///
    /// The inner service is declared with `Future = BoxFuture<'static, …>`, which is exactly
    /// `Self::Future`. Its future is therefore returned as-is; wrapping it in another
    /// `Box::pin` would add a heap allocation and an extra indirection on every request.
    #[inline]
    fn call(&self, req: Request<SgBody>) -> Self::Future {
        self.shared.call(req)
    }
}

/// Http backend service
///
/// This function could be a bottom layer of a http router, it will handle http and websocket request.
///
/// This can handle both websocket connection and http request.
///
/// # Errors
/// 1. Fail to collect body chunks
/// 2. Fail to upgrade
pub async fn http_backend_service_inner(mut req: Request<SgBody>) -> Result<SgResponse, BoxError> {
    tracing::trace!(elapsed = ?req.extensions().get::<crate::extension::EnterTime>().map(crate::extension::EnterTime::elapsed), "start a backend request");
    x_forwarded_for(&mut req)?;
    let mut client = get_client();
    let response = if req.headers().get(UPGRADE).is_some_and(|upgrade| upgrade.as_bytes().eq_ignore_ascii_case(b"websocket")) {
        // dump request
        let (part, body) = req.into_parts();
        let body = body.dump().await?;
        let req = Request::from_parts(part, body);

        // forward request
        let resp = client.request(req.clone()).await;

        // dump response
        let (part, body) = resp.into_parts();
        let body = body.dump().await?;
        let resp = Response::from_parts(part, body);

        let req_for_upgrade = req.clone();
        let resp_for_upgrade = resp.clone();

        // create forward task
        tokio::task::spawn(async move {
            // update both side
            let (s, c) = futures_util::join!(hyper::upgrade::on(req_for_upgrade), hyper::upgrade::on(resp_for_upgrade));
            let upgrade_as_server = s?;
            let upgrade_as_client = c?;
            // start a websocket forward
            ws_client_service::tcp_transfer(upgrade_as_server, upgrade_as_client).await?;
            <Result<(), BoxError>>::Ok(())
        });
        // return response to client
        resp
    } else {
        client.request(req).await
    };
    Ok(response)
}

#[instrument]
pub async fn http_backend_service(req: Request<SgBody>) -> Result<Response<SgBody>, Infallible> {
    match http_backend_service_inner(req).await {
        Ok(resp) => Ok(resp),
        Err(err) => Ok(Response::with_code_message(StatusCode::BAD_GATEWAY, format!("[Sg.Client] Client error: {err}"))),
    }
}

/// Builds the http/websocket backend service (`http_backend_service`) as a shareable
/// [`ArcHyperService`].
#[inline]
pub fn get_http_backend_service() -> ArcHyperService {
    let backend = hyper::service::service_fn(http_backend_service);
    ArcHyperService::new(backend)
}

/// Builds an [`ArcHyperService`] around the `echo::echo` handler.
///
/// `#[cold]` hints to the compiler that this constructor is expected to be called rarely.
#[cold]
#[inline]
pub fn get_echo_service() -> ArcHyperService {
    let echo_handler = hyper::service::service_fn(echo::echo);
    ArcHyperService::new(echo_handler)
}