1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
use crate::{extension::Reflect, SgBody, SgResponseExt};

use hyper::StatusCode;
use hyper::{Request, Response};
use hyper_rustls::HttpsConnector;
use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
use hyper_util::{
    client::legacy::{connect::HttpConnector, Client},
    rt::TokioExecutor,
};
use std::{
    collections::HashMap,
    sync::{Arc, Mutex, OnceLock},
    time::Duration,
};
use tokio_rustls::rustls::{self, client::danger::ServerCertVerifier, SignatureScheme};

/// A [`ServerCertVerifier`] that accepts every server certificate and every handshake
/// signature without inspecting them.
///
/// Installing it disables TLS server authentication entirely, so connections made with it
/// can be intercepted by anyone able to present *some* certificate. It is the verifier used
/// by [`get_rustls_config_dangerous`].
#[derive(Debug, Clone)]
pub struct NoCertificateVerification {}

impl ServerCertVerifier for NoCertificateVerification {
    /// Accepts the presented chain for any server name, ignoring OCSP data and the clock.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let verified = rustls::client::danger::ServerCertVerified::assertion();
        Ok(verified)
    }

    /// Accepts any TLS 1.2 handshake signature.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let valid = rustls::client::danger::HandshakeSignatureValid::assertion();
        Ok(valid)
    }

    /// Accepts any TLS 1.3 handshake signature; shares the unconditional TLS 1.2 path.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        self.verify_tls12_signature(message, cert, dss)
    }

    /// Advertises RSA PKCS#1 / PSS, ECDSA (P-256, P-384, P-521) and EdDSA schemes,
    /// in that order.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        Vec::from([
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::ECDSA_NISTP521_SHA512,
            SignatureScheme::ED25519,
            SignatureScheme::ED448,
        ])
    }
}

/// Builds a rustls client configuration that performs **no** server certificate verification.
///
/// The config starts from an empty root store with no client authentication, and its
/// verifier is then replaced by [`NoCertificateVerification`].
fn get_rustls_config_dangerous() -> rustls::ClientConfig {
    let mut tls_config = rustls::ClientConfig::builder()
        .with_root_certificates(rustls::RootCertStore::empty())
        .with_no_client_auth();
    // Swap in a verifier that accepts everything: certificate checking is fully disabled.
    tls_config.dangerous().set_certificate_verifier(Arc::new(NoCertificateVerification {}));
    tls_config
}

/// Returns a clone of the default client held by the process-wide [`ClientRepo`].
pub fn get_client() -> HttpClient {
    let repo = ClientRepo::global();
    repo.get_default()
}

/// Registry of [`HttpClient`]s keyed by a string code, with a fallback default client.
pub struct ClientRepo {
    /// Client returned when no code-specific client is requested or registered.
    default: HttpClient,
    /// Code-specific clients, behind a mutex so they can be registered through `&self`.
    repo: Mutex<HashMap<String, HttpClient>>,
}

impl Default for ClientRepo {
    /// Creates an empty registry whose default client is built from
    /// `get_rustls_config_dangerous`, i.e. it does not verify server certificates.
    fn default() -> Self {
        Self {
            default: HttpClient::new(get_rustls_config_dangerous()),
            repo: Mutex::new(HashMap::new()),
        }
    }
}

/// Storage for the process-wide [`ClientRepo`]; only ever accessed through raw pointers
/// (see [`ClientRepo::global`] and [`ClientRepo::set_global_default`]).
static mut GLOBAL: OnceLock<ClientRepo> = OnceLock::new();
impl ClientRepo {
    /// Returns a clone of the client registered under `code`, if any.
    ///
    /// # Panics
    /// Panics if the registry mutex is poisoned.
    pub fn get(&self, code: &str) -> Option<HttpClient> {
        self.repo.lock().expect("failed to lock client repo").get(code).cloned()
    }
    /// Returns the client registered under `code`, falling back to the default client.
    ///
    /// # Panics
    /// Panics if the registry mutex is poisoned.
    pub fn get_or_default(&self, code: &str) -> HttpClient {
        self.get(code).unwrap_or_else(|| self.default.clone())
    }
    /// Returns a clone of the default client.
    pub fn get_default(&self) -> HttpClient {
        self.default.clone()
    }
    /// Registers `client` under `code`, replacing any client previously registered there.
    ///
    /// # Panics
    /// Panics if the registry mutex is poisoned.
    pub fn register(&self, code: &str, client: HttpClient) {
        self.repo.lock().expect("failed to lock client repo").insert(code.to_string(), client);
    }
    /// Replaces the default client of this repository.
    pub fn set_default(&mut self, client: HttpClient) {
        self.default = client;
    }
    /// Returns the process-wide repository, initialising it with [`Default`] on first use.
    pub fn global() -> &'static Self {
        // SAFETY: only a *shared* reference to `GLOBAL` is formed. `OnceLock` is `Sync` and
        // `get_or_init` needs only `&self`, so any number of threads may do this at once.
        // Forming `&mut` here instead would let concurrent callers hold aliasing `&mut`
        // references to the same static, which is undefined behavior. The only mutable
        // access is in `set_global_default`, whose contract forbids overlapping with this.
        let global = unsafe { &*std::ptr::addr_of!(GLOBAL) };
        global.get_or_init(Default::default)
    }

    /// Replaces the default client of the process-wide repository, initialising the
    /// repository first if [`ClientRepo::global`] has not been called yet.
    ///
    /// # Safety
    /// This function is not thread safe, it should be called before any other thread is spawned.
    /// In addition, no reference obtained from [`ClientRepo::global`] may be alive across this
    /// call, because the global repository is mutated in place through a `&mut`.
    pub unsafe fn set_global_default(client: HttpClient) {
        // SAFETY: the caller guarantees exclusive access to `GLOBAL` for the duration of this
        // call. Going through `addr_of_mut!` avoids taking an implicit reference to the
        // `static mut` (rejected by the `static_mut_refs` lint in the 2024 edition).
        let global = unsafe { &mut *std::ptr::addr_of_mut!(GLOBAL) };
        global.get_or_init(Default::default);
        global.get_mut().expect("global not set").set_default(client);
    }
}

/// Configuration for an HTTP client; currently it only carries the TLS settings.
///
/// Derives `Debug` and `Clone` (both implemented by `rustls::ClientConfig`) so the config
/// can be logged and duplicated like other public types in this module.
#[derive(Debug, Clone)]
pub struct SgHttpClientConfig {
    /// rustls client configuration to use for TLS connections.
    pub tls_config: rustls::ClientConfig,
}

/// Clonable handle around a hyper-util client whose connector is built by
/// [`HttpClient::new`] and whose request/response bodies are [`SgBody`].
#[derive(Debug, Clone)]
pub struct HttpClient {
    inner: Client<HttpsConnector<HttpConnector>, SgBody>,
}

impl Default for HttpClient {
    /// Creates a client that verifies servers against the platform's native root
    /// certificates, without client authentication.
    ///
    /// # Panics
    /// Panics if the native root certificates cannot be loaded.
    fn default() -> Self {
        let tls_config = rustls::ClientConfig::builder()
            .with_native_roots()
            .expect("failed to init rustls config")
            .with_no_client_auth();
        Self::new(tls_config)
    }
}

impl HttpClient {
    pub fn new(tls_config: rustls::ClientConfig) -> Self {
        HttpClient {
            inner: Client::builder(TokioExecutor::new()).build(HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http().enable_http1().enable_http2().build()),
        }
    }
    pub fn new_dangerous() -> Self {
        let config = get_rustls_config_dangerous();
        Self::new(config)
    }
    pub async fn request(&mut self, mut req: Request<SgBody>) -> Response<SgBody> {
        let reflect = req.extensions_mut().remove::<Reflect>();
        match self.inner.request(req).await.map_err(Response::bad_gateway) {
            Ok(mut response) => {
                if let Some(reflect) = reflect {
                    response.extensions_mut().extend(reflect.into_inner());
                }
                response.map(SgBody::new)
            }
            Err(err) => err,
        }
    }
    pub async fn request_timeout(&mut self, req: Request<SgBody>, timeout: Duration) -> Response<SgBody> {
        let fut = self.request(req);
        let resp = tokio::time::timeout(timeout, fut).await;
        match resp {
            Ok(resp) => resp,
            Err(_) => Response::with_code_message(StatusCode::GATEWAY_TIMEOUT, "request timeout"),
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: sends a GET through the global default client and prints the
    /// response head and body. Needs outbound network access to www.baidu.com.
    #[tokio::test]
    async fn test_client() {
        let mut client = get_client();
        let request = Request::builder()
            .uri("https://www.baidu.com")
            .body(SgBody::empty())
            .unwrap();
        let (head, body) = client.request(request).await.into_parts();
        let dumped_body = body.dump().await.unwrap();
        let bytes = dumped_body.get_dumped().expect("no body");
        println!("{head:?}, {}", String::from_utf8_lossy(bytes));
    }
}