ws_tool/
connector.rs

1use http;
2use http::Uri;
3use crate::{errors::WsError, protocol::Mode};
4
5/// get websocket scheme
6pub fn get_scheme(uri: &http::Uri) -> Result<Mode, WsError> {
7    match uri.scheme_str().unwrap_or("ws").to_lowercase().as_str() {
8        "ws" => Ok(Mode::WS),
9        "wss" => Ok(Mode::WSS),
10        s => Err(WsError::InvalidUri(format!("unknown scheme {s}"))),
11    }
12}
13
14/// get host from uri
15pub fn get_host(uri: &Uri) -> Result<&str, WsError> {
16    uri.host()
17        .ok_or_else(|| WsError::InvalidUri(format!("can not find host {}", uri)))
18}
19
20#[cfg(feature = "sync")]
21mod blocking {
22    use crate::errors::WsError;
23    use http;
24    use std::net::TcpStream;
25
26    use super::{get_host, get_scheme};
27
28    /// performance tcp connection
29    pub fn tcp_connect(uri: &http::Uri) -> Result<TcpStream, WsError> {
30        let mode = get_scheme(uri)?;
31        let host = get_host(uri)?;
32        let port = uri.port_u16().unwrap_or_else(|| mode.default_port());
33        let stream = TcpStream::connect((host, port)).map_err(|e| {
34            WsError::ConnectionFailed(format!("failed to create tcp connection {e}"))
35        })?;
36        Ok(stream)
37    }
38
39    // #[cfg(feature = "sync_tls_rustls")]
40    // impl<S: std::io::Read + std::io::Write> crate::codec::Split for rustls_connector::TlsStream<S> {
41    //     type R = tokio::io::ReadHalf<BufStream<S>>;
42    //     type W = tokio::io::WriteHalf<BufStream<S>>;
43    //     fn split(self) -> (Self::R, Self::W) {
44    //         tokio::io::split(self)
45    //     }
46    // }
47
48    #[cfg(feature = "sync_tls_rustls")]
49    /// start tls session
50    pub fn wrap_rustls<
51        S: std::io::Read + std::io::Write + Sync + Send + std::fmt::Debug + 'static,
52    >(
53        stream: S,
54        host: &str,
55        certs: Vec<std::path::PathBuf>,
56    ) -> Result<rustls_connector::TlsStream<S>, WsError> {
57        use std::io::BufReader;
58
59        let mut config = rustls_connector::RustlsConnectorConfig::new_with_webpki_roots_certs();
60        let mut cert_data = vec![];
61        for cert_path in certs.iter() {
62            let mut pem = std::fs::File::open(cert_path).map_err(|_| {
63                WsError::CertFileNotFound(cert_path.to_str().unwrap_or_default().to_string())
64            })?;
65            let mut cert = BufReader::new(&mut pem);
66            let certs = rustls_pemfile::certs(&mut cert)
67                .map_err(|e| WsError::LoadCertFailed(e.to_string()))?;
68            cert_data.extend_from_slice(&certs);
69        }
70        config.add_parsable_certificates(&cert_data);
71        let connector = config.connector_with_no_client_auth();
72        let tls_stream = connector
73            .connect(host, stream)
74            .map_err(|e| WsError::ConnectionFailed(e.to_string()))?;
75        tracing::debug!("tls connection established");
76        Ok(tls_stream)
77    }
78
79    // #[cfg(feature = "sync_tls_native")]
80    // impl<S: std::io::Read + std::io::Write> crate::codec::Split for rustls_connector::TlsStream<S> {
81    //     type R = tokio::io::ReadHalf<BufStream<S>>;
82    //     type W = tokio::io::WriteHalf<BufStream<S>>;
83    //     fn split(self) -> (Self::R, Self::W) {
84    //         tokio::io::split(self)
85    //     }
86    // }
87
88    #[cfg(feature = "sync_tls_native")]
89    /// start tls session
90    pub fn wrap_native_tls<S: std::io::Read + std::io::Write>(
91        stream: S,
92        host: &str,
93        certs: Vec<std::path::PathBuf>,
94    ) -> Result<native_tls::TlsStream<S>, WsError> {
95        let mut builder = native_tls::TlsConnector::builder();
96        for cert_path in certs.iter() {
97            let mut pem = std::fs::File::open(cert_path).map_err(|_| {
98                WsError::CertFileNotFound(cert_path.to_str().unwrap_or_default().to_string())
99            })?;
100            let mut data = vec![];
101            if let Err(e) = std::io::Read::read_to_end(&mut pem, &mut data) {
102                tracing::error!(
103                    "failed to read cert file {} {}",
104                    cert_path.display(),
105                    e.to_string()
106                );
107                continue;
108            }
109            match native_tls::Certificate::from_der(&data) {
110                Ok(cert) => {
111                    builder.add_root_certificate(cert);
112                }
113                Err(e) => {
114                    tracing::error!(
115                        "invalid cert file {} {}",
116                        cert_path.display(),
117                        e.to_string()
118                    );
119                    continue;
120                }
121            }
122        }
123        let connector = builder.build().unwrap();
124        let tls_stream = connector
125            .connect(host, stream)
126            .map_err(|_| WsError::ConnectionFailed("tls connect failed".into()))?;
127        tracing::debug!("tls connection established");
128        Ok(tls_stream)
129    }
130}
131
132#[cfg(feature = "sync")]
133pub use blocking::*;
134
135#[cfg(feature = "async")]
136mod non_blocking {
137    use http::Uri;
138    use tokio::net::TcpStream;
139
140    use crate::errors::WsError;
141
142    use super::{get_host, get_scheme};
143
144    /// performance tcp connection
145    pub async fn async_tcp_connect(uri: &Uri) -> Result<TcpStream, WsError> {
146        let mode = get_scheme(uri)?;
147        let host = get_host(uri)?;
148        let port = uri.port_u16().unwrap_or_else(|| mode.default_port());
149
150        TcpStream::connect((host, port))
151            .await
152            .map_err(|e| WsError::ConnectionFailed(format!("failed to create tcp connection {e}")))
153    }
154
155    #[cfg(feature = "async_tls_rustls")]
156    impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> crate::codec::Split
157        for tokio_rustls::client::TlsStream<S>
158    {
159        type R = tokio::io::ReadHalf<tokio_rustls::client::TlsStream<S>>;
160        type W = tokio::io::WriteHalf<tokio_rustls::client::TlsStream<S>>;
161        fn split(self) -> (Self::R, Self::W) {
162            tokio::io::split(self)
163        }
164    }
165
166    #[cfg(feature = "async_tls_rustls")]
167    impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> crate::codec::Split
168        for tokio_rustls::server::TlsStream<S>
169    {
170        type R = tokio::io::ReadHalf<tokio_rustls::server::TlsStream<S>>;
171        type W = tokio::io::WriteHalf<tokio_rustls::server::TlsStream<S>>;
172        fn split(self) -> (Self::R, Self::W) {
173            tokio::io::split(self)
174        }
175    }
176
177    #[cfg(feature = "async_tls_rustls")]
178    /// async version of starting tls session
179    pub async fn async_wrap_rustls<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin>(
180        stream: S,
181        host: &str,
182        certs: Vec<std::path::PathBuf>,
183    ) -> Result<tokio_rustls::client::TlsStream<S>, WsError> {
184        use std::io::BufReader;
185
186        let mut root_store = rustls_connector::rustls::RootCertStore::empty();
187        root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
188            rustls_connector::rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
189                ta.subject,
190                ta.spki,
191                ta.name_constraints,
192            )
193        }));
194        let mut trust_anchors = vec![];
195        for cert_path in certs.iter() {
196            let mut pem = std::fs::File::open(cert_path).map_err(|_| {
197                WsError::CertFileNotFound(cert_path.to_str().unwrap_or_default().to_string())
198            })?;
199            let mut cert = BufReader::new(&mut pem);
200            let certs = rustls_pemfile::certs(&mut cert)
201                .map_err(|e| WsError::LoadCertFailed(e.to_string()))?;
202            for item in certs {
203                let ta = webpki::TrustAnchor::try_from_cert_der(&item[..])
204                    .map_err(|e| WsError::LoadCertFailed(e.to_string()))?;
205                let anchor =
206                    rustls_connector::rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
207                        ta.subject,
208                        ta.spki,
209                        ta.name_constraints,
210                    );
211                trust_anchors.push(anchor);
212            }
213        }
214        root_store.add_server_trust_anchors(trust_anchors.into_iter());
215        let config = rustls_connector::rustls::ClientConfig::builder()
216            .with_safe_defaults()
217            .with_root_certificates(root_store)
218            .with_no_client_auth();
219        let domain = tokio_rustls::rustls::ServerName::try_from(host)
220            .map_err(|e| WsError::TlsDnsFailed(e.to_string()))?;
221        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
222        let tls_stream = connector
223            .connect(domain, stream)
224            .await
225            .map_err(|e| WsError::ConnectionFailed(e.to_string()))?;
226        tracing::debug!("tls connection established");
227        Ok(tls_stream)
228    }
229
230    #[cfg(feature = "async_tls_native")]
231    impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> crate::codec::Split
232        for tokio_native_tls::TlsStream<S>
233    {
234        type R = tokio::io::ReadHalf<tokio_native_tls::TlsStream<S>>;
235        type W = tokio::io::WriteHalf<tokio_native_tls::TlsStream<S>>;
236        fn split(self) -> (Self::R, Self::W) {
237            tokio::io::split(self)
238        }
239    }
240
241    #[cfg(feature = "async_tls_native")]
242    /// start tls session
243    pub async fn async_wrap_native_tls<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin>(
244        stream: S,
245        host: &str,
246        certs: Vec<std::path::PathBuf>,
247    ) -> Result<tokio_native_tls::TlsStream<S>, WsError> {
248        let mut builder = native_tls::TlsConnector::builder();
249        for cert_path in certs.iter() {
250            let mut pem = std::fs::File::open(cert_path).map_err(|_| {
251                WsError::CertFileNotFound(cert_path.to_str().unwrap_or_default().to_string())
252            })?;
253            let mut data = vec![];
254            if let Err(e) = std::io::Read::read_to_end(&mut pem, &mut data) {
255                tracing::error!(
256                    "failed to read cert file {} {}",
257                    cert_path.display(),
258                    e.to_string()
259                );
260                continue;
261            }
262            match native_tls::Certificate::from_der(&data) {
263                Ok(cert) => {
264                    builder.add_root_certificate(cert);
265                }
266                Err(e) => {
267                    tracing::error!(
268                        "invalid cert file {} {}",
269                        cert_path.display(),
270                        e.to_string()
271                    );
272                    continue;
273                }
274            }
275        }
276        let connector = builder.build().unwrap();
277        let connector = tokio_native_tls::TlsConnector::from(connector);
278        let tls_stream = connector
279            .connect(host, stream)
280            .await
281            .map_err(|e| WsError::ConnectionFailed(e.to_string()))?;
282        tracing::debug!("tls connection established");
283        Ok(tls_stream)
284    }
285}
286
287#[cfg(feature = "async")]
288pub use non_blocking::*;