spacegate_kernel/
service.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use futures_util::future::BoxFuture;
4use hyper::{body::Incoming, Request, Response};
5use hyper_util::rt::TokioIo;
6use tokio::net::TcpStream;
7use tokio_rustls::rustls;
8
9use crate::{
10    extension::{EnterTime, PeerAddr, Reflect},
11    ArcHyperService, BoxResult, SgBody,
12};
13
14pub mod http_route;
15
16pub mod http_gateway;
17
18pub trait TcpService: 'static + Send + Sync {
19    fn protocol_name(&self) -> &str;
20    fn sniff_peek_size(&self) -> usize;
21    fn sniff(&self, peek_buf: &[u8]) -> bool;
22    fn handle(&self, stream: TcpStream, peer: SocketAddr) -> BoxFuture<'static, BoxResult<()>>;
23}
24type ConnectionBuilder = hyper_util::server::conn::auto::Builder<hyper_util::rt::TokioExecutor>;
25
26#[derive(Debug)]
27pub struct Http {
28    inner_service: ArcHyperService,
29    connection_builder: ConnectionBuilder,
30}
31
32impl Http {
33    pub fn new(service: ArcHyperService) -> Self {
34        Self {
35            inner_service: service,
36            connection_builder: ConnectionBuilder::new(Default::default()),
37        }
38    }
39}
40
41impl TcpService for Http {
42    fn protocol_name(&self) -> &str {
43        "http"
44    }
45    fn sniff_peek_size(&self) -> usize {
46        14
47    }
48    fn sniff(&self, peeked: &[u8]) -> bool {
49        peeked.starts_with(b"GET")
50            || peeked.starts_with(b"HEAD")
51            || peeked.starts_with(b"POST")
52            || peeked.starts_with(b"PUT")
53            || peeked.starts_with(b"DELETE")
54            || peeked.starts_with(b"CONNECT")
55            || peeked.starts_with(b"OPTIONS")
56            || peeked.starts_with(b"TRACE")
57            || peeked.starts_with(b"PATCH")
58            || peeked.starts_with(b"PRI * HTTP/2.0")
59    }
60    fn handle(&self, stream: TcpStream, peer: SocketAddr) -> BoxFuture<'static, BoxResult<()>> {
61        let io = TokioIo::new(stream);
62        let service = HyperServiceAdapter::new(self.inner_service.clone(), peer);
63        let builder = self.connection_builder.clone();
64        Box::pin(async move {
65            let conn = builder.serve_connection_with_upgrades(io, service);
66            conn.await
67        })
68    }
69}
70#[derive(Debug)]
71pub struct Https {
72    inner_service: ArcHyperService,
73    tls_config: Arc<rustls::ServerConfig>,
74    connection_builder: ConnectionBuilder,
75}
76
77impl Https {
78    pub fn new(service: ArcHyperService, tls_config: rustls::ServerConfig) -> Self {
79        Self {
80            inner_service: service,
81            tls_config: Arc::new(tls_config),
82            connection_builder: ConnectionBuilder::new(Default::default()),
83        }
84    }
85}
86
87impl TcpService for Https {
88    fn protocol_name(&self) -> &str {
89        "https"
90    }
91    fn sniff_peek_size(&self) -> usize {
92        5
93    }
94    fn sniff(&self, peeked: &[u8]) -> bool {
95        peeked.starts_with(b"\x16\x03")
96    }
97    fn handle(&self, stream: TcpStream, peer: SocketAddr) -> BoxFuture<'static, BoxResult<()>> {
98        let service = HyperServiceAdapter::new(self.inner_service.clone(), peer);
99        let builder = self.connection_builder.clone();
100        let connector = tokio_rustls::TlsAcceptor::from(self.tls_config.clone());
101        Box::pin(async move {
102            let accepted = connector.accept(stream).await?;
103            let conn = builder.serve_connection_with_upgrades(TokioIo::new(accepted), service);
104            conn.await
105        })
106    }
107}
108
109#[derive(Clone, Debug)]
110pub struct HyperServiceAdapter<S>
111where
112    S: hyper::service::Service<Request<SgBody>, Error = Infallible, Response = Response<SgBody>> + Clone + Send + 'static,
113    S::Future: Send + 'static,
114{
115    service: S,
116    peer: SocketAddr,
117}
118
119impl<S> HyperServiceAdapter<S>
120where
121    S: hyper::service::Service<Request<SgBody>, Error = Infallible, Response = Response<SgBody>> + Clone + Send + 'static,
122    S::Future: Send + 'static,
123{
124    pub fn new(service: S, peer: SocketAddr) -> Self {
125        Self { service, peer }
126    }
127}
128
129impl<S> hyper::service::Service<Request<Incoming>> for HyperServiceAdapter<S>
130where
131    S: hyper::service::Service<Request<SgBody>, Error = Infallible, Response = Response<SgBody>> + Clone + Send + 'static,
132    S::Future: Send + 'static,
133{
134    type Response = Response<SgBody>;
135    type Error = Infallible;
136    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
137
138    #[inline]
139    fn call(&self, mut req: Request<Incoming>) -> Self::Future {
140        req.extensions_mut().insert(self.peer);
141        // here we will clone underlying service,
142        // so it's important that underlying service is cheap to clone.
143        // here, the service are likely to be a `ArcHyperService` so it's ok
144        // but if underlying service is big, it will be expensive to clone.
145        // especially the router is big and the too many plugins are installed.
146        // so we should avoid that
147        let enter_time = EnterTime::new();
148        let service = self.service.clone();
149        let mut req = req.map(SgBody::new);
150        let mut reflect = Reflect::default();
151        // let method = req.method().clone();
152        reflect.insert(enter_time);
153        req.extensions_mut().insert(reflect);
154        req.extensions_mut().insert(PeerAddr(self.peer));
155        req.extensions_mut().insert(enter_time);
156        Box::pin(async move {
157            let resp = service.call(req).await.expect("infallible");
158            // if method != hyper::Method::HEAD && method != hyper::Method::OPTIONS && method != hyper::Method::CONNECT {
159            //     with_length_or_chunked(&mut resp);
160            // }
161            let status = resp.status();
162            if status.is_server_error() {
163                tracing::warn!(status = ?status, headers = ?resp.headers(), "server error response");
164            } else if status.is_client_error() {
165                tracing::debug!(status = ?status, headers = ?resp.headers(), "client error response");
166            } else if status.is_success() {
167                tracing::trace!(status = ?status, headers = ?resp.headers(), "success response");
168            }
169            tracing::trace!(latency = ?enter_time.elapsed(), "request finished");
170            Ok(resp)
171        })
172    }
173}
174
175impl ArcHyperService {
176    pub fn http(self) -> Http {
177        Http::new(self)
178    }
179    pub fn https(self, tls_config: rustls::ServerConfig) -> Https {
180        Https::new(self, tls_config)
181    }
182}