spacegate_kernel/backend_service/
http_client_service.rs

1use crate::{extension::Reflect, SgBody, SgResponseExt};
2
3use hyper::StatusCode;
4use hyper::{Request, Response};
5use hyper_rustls::HttpsConnector;
6use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
7use hyper_util::{
8    client::legacy::{connect::HttpConnector, Client},
9    rt::TokioExecutor,
10};
11use std::{
12    collections::HashMap,
13    sync::{Arc, Mutex, OnceLock},
14    time::Duration,
15};
16use tokio_rustls::rustls::{self, client::danger::ServerCertVerifier, SignatureScheme};
17
18#[derive(Debug, Clone)]
19pub struct NoCertificateVerification {}
20impl ServerCertVerifier for NoCertificateVerification {
21    fn verify_server_cert(
22        &self,
23        _end_entity: &rustls::pki_types::CertificateDer<'_>,
24        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
25        _server_name: &rustls::pki_types::ServerName<'_>,
26        _ocsp_response: &[u8],
27        _now: rustls::pki_types::UnixTime,
28    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
29        Ok(rustls::client::danger::ServerCertVerified::assertion())
30    }
31
32    fn verify_tls12_signature(
33        &self,
34        _message: &[u8],
35        _cert: &rustls::pki_types::CertificateDer<'_>,
36        _dss: &rustls::DigitallySignedStruct,
37    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
38        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
39    }
40
41    fn verify_tls13_signature(
42        &self,
43        _message: &[u8],
44        _cert: &rustls::pki_types::CertificateDer<'_>,
45        _dss: &rustls::DigitallySignedStruct,
46    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
47        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
48    }
49
50    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
51        vec![
52            SignatureScheme::RSA_PKCS1_SHA256,
53            SignatureScheme::RSA_PKCS1_SHA384,
54            SignatureScheme::RSA_PKCS1_SHA512,
55            SignatureScheme::RSA_PSS_SHA256,
56            SignatureScheme::RSA_PSS_SHA384,
57            SignatureScheme::RSA_PSS_SHA512,
58            SignatureScheme::ECDSA_NISTP256_SHA256,
59            SignatureScheme::ECDSA_NISTP384_SHA384,
60            SignatureScheme::ECDSA_NISTP521_SHA512,
61            SignatureScheme::ED25519,
62            SignatureScheme::ED448,
63        ]
64    }
65}
66
67fn get_rustls_config_dangerous() -> rustls::ClientConfig {
68    let store = rustls::RootCertStore::empty();
69    let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
70    let mut config = rustls::ClientConfig::builder().with_root_certificates(store).with_no_client_auth();
71    // completely disable cert-verification
72    let mut dangerous_config = rustls::ClientConfig::dangerous(&mut config);
73    dangerous_config.set_certificate_verifier(Arc::new(NoCertificateVerification {}));
74    config
75}
76
77pub fn get_client() -> HttpClient {
78    ClientRepo::global().get_default()
79}
80
81#[derive(Debug)]
82pub struct ClientRepo {
83    default: HttpClient,
84    repo: Mutex<HashMap<String, HttpClient>>,
85}
86
87impl Default for ClientRepo {
88    fn default() -> Self {
89        let config = get_rustls_config_dangerous();
90        let default = HttpClient::new(config);
91        Self {
92            default,
93            repo: Default::default(),
94        }
95    }
96}
97
98static mut GLOBAL: OnceLock<ClientRepo> = OnceLock::new();
99impl ClientRepo {
100    pub fn get(&self, code: &str) -> Option<HttpClient> {
101        self.repo.lock().expect("failed to lock client repo").get(code).cloned()
102    }
103    pub fn get_or_default(&self, code: &str) -> HttpClient {
104        self.get(code).unwrap_or_else(|| self.default.clone())
105    }
106    pub fn get_default(&self) -> HttpClient {
107        self.default.clone()
108    }
109    pub fn register(&self, code: &str, client: HttpClient) {
110        self.repo.lock().expect("failed to lock client repo").insert(code.to_string(), client);
111    }
112    pub fn set_default(&mut self, client: HttpClient) {
113        self.default = client;
114    }
115    pub fn global() -> &'static Self {
116        unsafe { std::ptr::addr_of!(GLOBAL).cast_mut().as_mut().expect("invalid static global client repo") }.get_or_init(Default::default)
117    }
118
119    /// # Safety
120    /// This function is not thread safe, it should be called before any other thread is spawned.
121    pub unsafe fn set_global_default(client: HttpClient) {
122        GLOBAL.get_or_init(Default::default);
123        GLOBAL.get_mut().expect("global not set").set_default(client);
124    }
125}
126#[derive(Debug)]
127pub struct SgHttpClientConfig {
128    pub tls_config: rustls::ClientConfig,
129}
130
131#[derive(Debug, Clone)]
132pub struct HttpClient {
133    inner: Client<HttpsConnector<HttpConnector>, SgBody>,
134}
135
136impl Default for HttpClient {
137    fn default() -> Self {
138        Self::new(rustls::ClientConfig::builder().with_native_roots().expect("failed to init rustls config").with_no_client_auth())
139    }
140}
141
142impl HttpClient {
143    pub fn new(tls_config: rustls::ClientConfig) -> Self {
144        HttpClient {
145            inner: Client::builder(TokioExecutor::new()).build(HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http().enable_http1().enable_http2().build()),
146        }
147    }
148    pub fn new_h1_only(tls_config: rustls::ClientConfig) -> Self {
149        HttpClient {
150            inner: Client::builder(TokioExecutor::new()).build(HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http().enable_http1().build()),
151        }
152    }
153    pub fn new_dangerous() -> Self {
154        let config = get_rustls_config_dangerous();
155        Self::new(config)
156    }
157    pub fn new_dangerous_h1_only() -> Self {
158        let config = get_rustls_config_dangerous();
159        Self::new_h1_only(config)
160    }
161    pub async fn request(&mut self, mut req: Request<SgBody>) -> Response<SgBody> {
162        let reflect = req.extensions_mut().remove::<Reflect>();
163        match self.inner.request(req).await.map_err(Response::bad_gateway) {
164            Ok(mut response) => {
165                if let Some(reflect) = reflect {
166                    response.extensions_mut().extend(reflect.into_inner());
167                }
168                response.map(SgBody::new)
169            }
170            Err(err) => err,
171        }
172    }
173    pub async fn request_timeout(&mut self, req: Request<SgBody>, timeout: Duration) -> Response<SgBody> {
174        let fut = self.request(req);
175        let resp = tokio::time::timeout(timeout, fut).await;
176        match resp {
177            Ok(resp) => resp,
178            Err(_) => Response::with_code_message(StatusCode::GATEWAY_TIMEOUT, "request timeout"),
179        }
180    }
181}
182
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: sends a real HTTPS request through the global default client and prints
    /// the response head and body. Requires outbound network access to www.baidu.com.
    #[tokio::test]
    async fn test_client() {
        let mut client = get_client();
        let request = Request::builder().uri("https://www.baidu.com").body(SgBody::empty()).unwrap();
        let (part, body) = client.request(request).await.into_parts();
        let dumped_body = body.dump().await.unwrap();
        let bytes = dumped_body.get_dumped().expect("no body");
        println!("{part:?}, {}", String::from_utf8_lossy(bytes));
    }
}