1use crate::errors::Result;
2use crate::proxy::{Proxy, ProxySocket};
3use crate::socket::{Socket, StreamWrapper};
4#[cfg(feature = "tls")]
5use crate::tls::{self, Certificate, CustomTlsConnector, Identity};
6use socket2::Socket as RawSocket;
7use socket2::{Domain, Protocol, Type};
8use std::net::SocketAddr;
9use std::time::Duration;
10use tokio::net::TcpSocket;
11
12#[derive(Clone)]
14pub struct ConnectorBuilder {
15 read_timeout: Option<Duration>,
16 write_timeout: Option<Duration>,
17 connect_timeout: Option<Duration>,
18 nodelay: bool,
19 keepalive: bool,
20 proxy: Option<Proxy>,
21 #[cfg(feature = "tls")]
22 tls_config: TlsConfig,
23 #[cfg(feature = "tls")]
24 custom_tls_connector: Option<std::sync::Arc<dyn CustomTlsConnector>>,
25}
26
27impl Default for ConnectorBuilder {
28 fn default() -> Self {
29 Self {
30 read_timeout: Some(Duration::from_secs(30)),
31 write_timeout: Some(Duration::from_secs(30)),
32 connect_timeout: Some(Duration::from_secs(10)),
33 nodelay: false,
34 keepalive: false,
35 proxy: None,
36 #[cfg(feature = "tls")]
37 tls_config: TlsConfig::default(),
38 #[cfg(feature = "tls")]
39 custom_tls_connector: None,
40 }
41 }
42}
43
/// Options for the built-in TLS backend, held inside [`ConnectorBuilder`].
#[cfg(feature = "tls")]
#[derive(Clone)]
pub struct TlsConfig {
    /// When `false`, the rustls backend accepts any server certificate (dangerous).
    pub certs_verification: bool,
    /// When `false` (and `certs_verification` is `true`), the rustls backend skips the
    /// hostname check.
    pub hostname_verification: bool,
    /// Lowest protocol version to negotiate; `None` means no lower bound.
    pub min_tls_version: Option<tls::Version>,
    /// Highest protocol version to negotiate; `None` means no upper bound.
    pub max_tls_version: Option<tls::Version>,
    /// Whether the Server Name Indication extension should be sent.
    /// NOTE(review): confirm the active TLS backend actually honours this flag.
    pub tls_sni: bool,
    /// Optional client identity, presented for client authentication instead of none.
    pub identity: Option<Identity>,
    /// Additional trusted root certificates, added alongside the platform's native roots.
    pub certificate: Vec<Certificate>,
    /// When `true`, the rustls backend advertises HTTP/2 via ALPN.
    #[cfg(feature = "http2")]
    pub http2: bool,
}
#[cfg(feature = "rustls")]
impl TlsConfig {
    /// Builds the built-in rustls [`CustomTlsConnector`] described by this configuration.
    ///
    /// - Trust anchors are the user-supplied `certificate`s plus every parsable certificate
    ///   from the platform's native root store.
    /// - Protocol versions are limited by `min_tls_version` / `max_tls_version`.
    /// - Certificate/hostname verification, SNI, the client identity and (with the `http2`
    ///   feature) ALPN are configured from the corresponding fields.
    /// - `connect_timeout` is forwarded to the resulting connector.
    ///
    /// # Errors
    ///
    /// Fails if a user-supplied certificate or the client identity cannot be loaded. It also
    /// fails if the version bounds leave no protocol version to negotiate.
    fn custom(
        &self,
        connect_timeout: Option<Duration>,
    ) -> Result<std::sync::Arc<dyn CustomTlsConnector>> {
        let root_cert_store = self.root_cert_store()?;
        let versions = self.protocol_versions()?;
        // Use the process-wide default provider if one was installed, otherwise ring.
        let provider = tokio_rustls::rustls::crypto::CryptoProvider::get_default()
            .cloned()
            .unwrap_or_else(|| {
                std::sync::Arc::new(tokio_rustls::rustls::crypto::ring::default_provider())
            });
        // Copied out before `provider` is moved into the builder; the hostname-ignoring
        // verifier below needs it.
        let signature_algorithms = provider.signature_verification_algorithms;
        let config_builder = tokio_rustls::rustls::ClientConfig::builder_with_provider(provider)
            .with_protocol_versions(&versions)
            .map_err(|_| crate::errors::builder("invalid TLS versions"))?;
        let config_builder = if !self.certs_verification {
            // Dangerous: any server certificate is accepted.
            config_builder
                .dangerous()
                .with_custom_certificate_verifier(std::sync::Arc::new(tls::rustls::NoVerifier))
        } else if !self.hostname_verification {
            // `IgnoreHostname` receives the root store, presumably so the chain is still
            // verified while the hostname check is skipped.
            config_builder
                .dangerous()
                .with_custom_certificate_verifier(std::sync::Arc::new(
                    tls::rustls::IgnoreHostname::new(root_cert_store, signature_algorithms),
                ))
        } else {
            config_builder.with_root_certificates(root_cert_store)
        };
        let mut rustls_config = if let Some(id) = self.identity.clone() {
            id.add_to_tls(config_builder)?
        } else {
            config_builder.with_no_client_auth()
        };
        // Apply the SNI setting; previously `tls_sni` was stored but never used.
        rustls_config.enable_sni = self.tls_sni;
        #[cfg(feature = "http2")]
        {
            if self.http2 {
                // ALPN lists are in descending order of preference (RFC 7301 §3.1),
                // so h2 must come first when HTTP/2 is enabled.
                rustls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
            }
        }
        Ok(std::sync::Arc::new(tls::rustls::RustlsTlsConnector::new(
            tokio_rustls::TlsConnector::from(std::sync::Arc::new(rustls_config)),
            connect_timeout,
        )))
    }

    /// Builds the trust anchors: user-supplied certificates plus the platform's native roots.
    ///
    /// # Errors
    ///
    /// Fails if a user-supplied certificate cannot be added.
    fn root_cert_store(&self) -> Result<tokio_rustls::rustls::RootCertStore> {
        let mut store = tokio_rustls::rustls::RootCertStore::empty();
        for cert in self.certificate.clone() {
            cert.add_to_tls(&mut store)?;
        }
        // Native stores commonly contain certificates rustls cannot parse. Skip those
        // instead of failing the whole build, which would also make `Connector::default`
        // panic. Errors from loading the native store are tolerated, as before.
        let native = rustls_native_certs::load_native_certs();
        let (_added, _ignored) = store.add_parsable_certificates(native.certs);
        Ok(store)
    }

    /// Returns the rustls protocol versions permitted by `min_tls_version` / `max_tls_version`.
    ///
    /// # Errors
    ///
    /// Fails when the bounds exclude every supported version.
    fn protocol_versions(
        &self,
    ) -> Result<Vec<&'static tokio_rustls::rustls::SupportedProtocolVersion>> {
        let mut versions = tokio_rustls::rustls::ALL_VERSIONS.to_vec();
        // NOTE(review): a version that `tls::Version::from_tls` cannot map is kept by the
        // lower bound but dropped by the upper bound. This asymmetry is preserved; confirm
        // it is intended.
        if let Some(min_tls_version) = self.min_tls_version {
            versions.retain(|&supported_version| {
                match tls::Version::from_tls(supported_version.version) {
                    Some(version) => version >= min_tls_version,
                    None => true,
                }
            });
        }
        if let Some(max_tls_version) = self.max_tls_version {
            versions.retain(|&supported_version| {
                match tls::Version::from_tls(supported_version.version) {
                    Some(version) => version <= max_tls_version,
                    None => false,
                }
            });
        }
        if versions.is_empty() {
            return Err(crate::errors::builder("empty supported tls versions"));
        }
        Ok(versions)
    }
}
#[cfg(feature = "tls")]
impl Default for TlsConfig {
    /// Verifies both certificates and hostnames, sets no version bounds and sets the SNI
    /// flag. There is no client identity, no extra root certificates, and (with `http2`)
    /// no HTTP/2.
    fn default() -> Self {
        TlsConfig {
            certs_verification: true,
            hostname_verification: true,
            tls_sni: true,
            min_tls_version: None,
            max_tls_version: None,
            identity: None,
            certificate: Vec::new(),
            #[cfg(feature = "http2")]
            http2: false,
        }
    }
}
151
152impl ConnectorBuilder {
153 #[cfg(feature = "http2")]
154 pub fn enable_http2(mut self, http2: bool) -> Self {
156 self.tls_config.http2 = http2;
157 self
158 }
159 #[cfg(feature = "tls")]
160 pub fn hostname_verification(mut self, value: bool) -> ConnectorBuilder {
164 self.tls_config.hostname_verification = value;
165 self
166 }
167 #[cfg(feature = "tls")]
168 pub fn certs_verification(mut self, value: bool) -> ConnectorBuilder {
172 self.tls_config.certs_verification = value;
173 self
174 }
175 pub fn nodelay(mut self, value: bool) -> ConnectorBuilder {
179 self.nodelay = value;
180 self
181 }
182 pub fn keepalive(mut self, value: bool) -> ConnectorBuilder {
186 self.keepalive = value;
187 self
188 }
189 #[cfg(feature = "tls")]
193 pub fn tls_sni(mut self, value: bool) -> ConnectorBuilder {
194 self.tls_config.tls_sni = value;
195 self
196 }
197 #[cfg(feature = "tls")]
199 pub fn certificate(mut self, value: Vec<Certificate>) -> ConnectorBuilder {
200 self.tls_config.certificate = value;
201 self
202 }
203 #[cfg(feature = "tls")]
205 pub fn identity(mut self, value: Identity) -> ConnectorBuilder {
206 self.tls_config.identity = Some(value);
207 self
208 }
209 pub fn read_timeout(mut self, timeout: Option<Duration>) -> ConnectorBuilder {
217 self.read_timeout = timeout;
218 self
219 }
220 pub fn write_timeout(mut self, timeout: Option<Duration>) -> ConnectorBuilder {
228 self.write_timeout = timeout;
229 self
230 }
231 pub fn connect_timeout(mut self, timeout: Option<Duration>) -> ConnectorBuilder {
240 self.connect_timeout = timeout;
241 self
242 }
243 pub fn proxy(mut self, addr: Option<Proxy>) -> ConnectorBuilder {
251 self.proxy = addr;
252 self
253 }
254 #[cfg(feature = "tls")]
262 pub fn min_tls_version(mut self, version: Option<tls::Version>) -> ConnectorBuilder {
263 self.tls_config.min_tls_version = version;
264 self
265 }
266 #[cfg(feature = "tls")]
274 pub fn max_tls_version(mut self, version: Option<tls::Version>) -> ConnectorBuilder {
275 self.tls_config.max_tls_version = version;
276 self
277 }
278
279 #[cfg(feature = "tls")]
313 pub fn custom_tls_connector(
314 mut self,
315 connector: std::sync::Arc<dyn CustomTlsConnector>,
316 ) -> ConnectorBuilder {
317 self.custom_tls_connector = Some(connector);
318 self
319 }
320}
321
322impl ConnectorBuilder {
323 pub fn build(&self) -> Result<Connector> {
325 #[cfg(feature = "tls")]
326 let tls = {
327 if let Some(custom) = &self.custom_tls_connector {
329 custom.clone()
330 } else {
331 #[cfg(feature = "rustls")]
332 {
333 self.tls_config.custom(self.connect_timeout)?
335 }
336 #[cfg(not(feature = "rustls"))]
337 {
338 return Err(crate::errors::builder(
339 "TLS feature enabled without backend: please enable 'rustls' feature, or provide a custom TLS connector using .custom_tls_connector()",
340 ));
341 }
342 }
343 };
344 let conn = Connector {
345 connect_timeout: self.connect_timeout,
346 nodelay: self.nodelay,
347 keepalive: self.keepalive,
348 read_timeout: self.read_timeout,
349 write_timeout: self.write_timeout,
350 proxy: self.proxy.clone(),
351 #[cfg(feature = "tls")]
352 tls,
353 };
354 Ok(conn)
355 }
356}
357
358pub struct Connector {
361 connect_timeout: Option<Duration>,
362 nodelay: bool,
363 keepalive: bool,
364 read_timeout: Option<Duration>,
365 write_timeout: Option<Duration>,
366 proxy: Option<Proxy>,
367 #[cfg(feature = "tls")]
368 tls: std::sync::Arc<dyn CustomTlsConnector>,
369}
370
371impl PartialEq for Connector {
372 fn eq(&self, _other: &Self) -> bool {
373 true
374 }
375}
376
377impl Connector {
378 pub async fn connect_with_addr<S: Into<SocketAddr>>(&self, addr: S) -> Result<Socket> {
380 let addr = addr.into();
381 let raw_socket = RawSocket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
382 raw_socket.set_nonblocking(true)?;
383 let socket = TcpSocket::from_std_stream(raw_socket.into());
387 if self.nodelay {
388 socket.set_nodelay(self.nodelay)?;
389 }
390 if self.keepalive {
391 socket.set_keepalive(self.keepalive)?;
392 }
393 let s = match self.connect_timeout {
394 None => socket.connect(addr).await?,
395 Some(timeout) => tokio::time::timeout(timeout, socket.connect(addr))
396 .await
397 .map_err(|x| crate::errors::new_io_error(std::io::ErrorKind::TimedOut, &x.to_string()))??,
398 };
399 Ok(Socket::new(
400 StreamWrapper::Tcp(s),
401 self.read_timeout,
402 self.write_timeout,
403 ))
404 }
405 pub async fn connect_with_uri(&self, target: &http::Uri) -> Result<Socket> {
407 ProxySocket::new(target, &self.proxy)
408 .conn_with_connector(self)
409 .await
410 }
411 #[cfg(feature = "tls")]
412 pub async fn upgrade_to_tls(&self, stream: Socket, domain: &str) -> Result<Socket> {
414 self.tls.connect(domain, stream).await
415 }
416}
417
418impl Default for Connector {
420 fn default() -> Self {
421 ConnectorBuilder::default()
422 .build()
423 .expect("new default connector failure")
424 }
425}