slinger/
connector.rs

1use crate::errors::Result;
2use crate::proxy::{Proxy, ProxySocket};
3use crate::socket::{Socket, StreamWrapper};
4#[cfg(feature = "tls")]
5use crate::tls::{self, Certificate, CustomTlsConnector, Identity};
6use socket2::Socket as RawSocket;
7use socket2::{Domain, Protocol, Type};
8use std::net::SocketAddr;
9use std::time::Duration;
10use tokio::net::TcpSocket;
11
12/// ConnectorBuilder
13#[derive(Clone)]
14pub struct ConnectorBuilder {
15  read_timeout: Option<Duration>,
16  write_timeout: Option<Duration>,
17  connect_timeout: Option<Duration>,
18  nodelay: bool,
19  keepalive: bool,
20  proxy: Option<Proxy>,
21  #[cfg(feature = "tls")]
22  tls_config: TlsConfig,
23  #[cfg(feature = "tls")]
24  custom_tls_connector: Option<std::sync::Arc<dyn CustomTlsConnector>>,
25}
26
27impl Default for ConnectorBuilder {
28  fn default() -> Self {
29    Self {
30      read_timeout: Some(Duration::from_secs(30)),
31      write_timeout: Some(Duration::from_secs(30)),
32      connect_timeout: Some(Duration::from_secs(10)),
33      nodelay: false,
34      keepalive: false,
35      proxy: None,
36      #[cfg(feature = "tls")]
37      tls_config: TlsConfig::default(),
38      #[cfg(feature = "tls")]
39      custom_tls_connector: None,
40    }
41  }
42}
43
#[cfg(feature = "tls")]
/// TLS settings collected by [`ConnectorBuilder`] and used to build the
/// default rustls-backed connector.
#[derive(Clone)]
pub struct TlsConfig {
  /// Whether to advertise `h2` (alongside `http/1.1`) via ALPN.
  #[cfg(feature = "http2")]
  pub http2: bool,
  /// Whether the server certificate must match the requested host name.
  pub hostname_verification: bool,
  /// Whether the server certificate is verified at all; `false` disables every check.
  pub certs_verification: bool,
  /// Lowest TLS version that may be negotiated (`None` = no lower bound).
  pub min_tls_version: Option<tls::Version>,
  /// Highest TLS version that may be negotiated (`None` = no upper bound).
  pub max_tls_version: Option<tls::Version>,
  /// Whether the Server Name Indication extension should be sent.
  pub tls_sni: bool,
  /// Client identity used for client-certificate authentication.
  pub identity: Option<Identity>,
  /// Extra root certificates trusted in addition to the platform's native roots.
  pub certificate: Vec<Certificate>,
}
#[cfg(feature = "rustls")]
impl TlsConfig {
  /// Builds the default rustls-backed [`CustomTlsConnector`] from this configuration.
  ///
  /// - Trust roots are the user-supplied `certificate`s plus every certificate from the
  ///   platform store that rustls can parse.
  /// - Verification follows `certs_verification` (checked first) and then
  ///   `hostname_verification`.
  /// - The protocol list is narrowed by `min_tls_version` / `max_tls_version`.
  /// - SNI follows `tls_sni`.
  /// - With the `http2` feature, ALPN advertises `h2` when `http2` is set.
  ///
  /// `connect_timeout` is forwarded unchanged to `RustlsTlsConnector::new`.
  ///
  /// # Errors
  ///
  /// Returns an error in these cases:
  /// - a user-supplied certificate or the client identity cannot be added;
  /// - the version bounds leave no supported TLS version;
  /// - rustls rejects the resulting version set.
  fn custom(
    &self,
    connect_timeout: Option<Duration>,
  ) -> Result<std::sync::Arc<dyn CustomTlsConnector>> {
    let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
    // User-supplied roots must be valid: any failure here is reported to the caller.
    for cert in self.certificate.clone() {
      cert.add_to_tls(&mut root_cert_store)?;
    }
    // Platform stores may contain certificates rustls cannot parse. Skip those instead of
    // failing the whole build; `add_parsable_certificates` exists for exactly this case.
    // Errors reported by the native loader are ignored, as before.
    let _ = root_cert_store
      .add_parsable_certificates(rustls_native_certs::load_native_certs().certs);
    let mut versions = tokio_rustls::rustls::ALL_VERSIONS.to_vec();
    if let Some(min_tls_version) = self.min_tls_version {
      versions.retain(|&supported_version| {
        match tls::Version::from_tls(supported_version.version) {
          Some(version) => version >= min_tls_version,
          // Assume it's so new we don't know about it, allow it
          // (as of writing this is unreachable)
          None => true,
        }
      });
    }
    if let Some(max_tls_version) = self.max_tls_version {
      versions.retain(|&supported_version| {
        match tls::Version::from_tls(supported_version.version) {
          Some(version) => version <= max_tls_version,
          // A version unknown to us cannot be proven to be below the maximum.
          None => false,
        }
      });
    }
    if versions.is_empty() {
      return Err(crate::errors::builder("empty supported tls versions"));
    }
    // Prefer a process-wide provider if one was installed; otherwise fall back to ring.
    let provider = tokio_rustls::rustls::crypto::CryptoProvider::get_default()
      .cloned()
      .unwrap_or_else(|| std::sync::Arc::new(tokio_rustls::rustls::crypto::ring::default_provider()));
    let signature_algorithms = provider.signature_verification_algorithms;
    let config_builder =
      tokio_rustls::rustls::ClientConfig::builder_with_provider(provider.clone())
        .with_protocol_versions(&versions)
        .map_err(|_| crate::errors::builder("invalid TLS versions"))?;
    let config_builder = if !self.certs_verification {
      config_builder
        .dangerous()
        .with_custom_certificate_verifier(std::sync::Arc::new(tls::rustls::NoVerifier))
    } else if !self.hostname_verification {
      config_builder
        .dangerous()
        .with_custom_certificate_verifier(std::sync::Arc::new(tls::rustls::IgnoreHostname::new(
          root_cert_store,
          signature_algorithms,
        )))
    } else {
      config_builder.with_root_certificates(root_cert_store)
    };
    let mut rustls_config = if let Some(id) = self.identity.clone() {
      id.add_to_tls(config_builder)?
    } else {
      config_builder.with_no_client_auth()
    };
    // rustls sends SNI by default. Apply the builder's `tls_sni` switch, which was
    // previously never forwarded to the client config.
    rustls_config.enable_sni = self.tls_sni;
    #[cfg(feature = "http2")]
    let rustls_config = {
      let mut config = rustls_config;
      if self.http2 {
        // NOTE(review): `http/1.1` is offered before `h2`; clients usually list `h2`
        // first. Confirm this ordering is intentional.
        config.alpn_protocols = vec![b"http/1.1".to_vec(), b"h2".to_vec()];
      }
      config
    };
    Ok(std::sync::Arc::new(tls::rustls::RustlsTlsConnector::new(
      tokio_rustls::TlsConnector::from(std::sync::Arc::new(rustls_config)),
      connect_timeout,
    )))
  }
}
#[cfg(feature = "tls")]
impl Default for TlsConfig {
  /// Secure defaults:
  /// - certificate and hostname verification on, SNI on;
  /// - no TLS version bounds;
  /// - no client identity and no extra root certificates;
  /// - HTTP/2 (when compiled in) off.
  fn default() -> Self {
    TlsConfig {
      hostname_verification: true,
      certs_verification: true,
      tls_sni: true,
      min_tls_version: None,
      max_tls_version: None,
      identity: None,
      certificate: Vec::new(),
      #[cfg(feature = "http2")]
      http2: false,
    }
  }
}
151
152impl ConnectorBuilder {
153  #[cfg(feature = "http2")]
154  /// Enable HTTP/2 support.
155  pub fn enable_http2(mut self, http2: bool) -> Self {
156    self.tls_config.http2 = http2;
157    self
158  }
159  #[cfg(feature = "tls")]
160  /// Controls the use of hostname verification.
161  ///
162  /// Defaults to `true`.
163  pub fn hostname_verification(mut self, value: bool) -> ConnectorBuilder {
164    self.tls_config.hostname_verification = value;
165    self
166  }
167  #[cfg(feature = "tls")]
168  /// Controls the use of certificate validation.
169  ///
170  /// Defaults to `true`.
171  pub fn certs_verification(mut self, value: bool) -> ConnectorBuilder {
172    self.tls_config.certs_verification = value;
173    self
174  }
175  /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
176  ///
177  /// Default is `false`.
178  pub fn nodelay(mut self, value: bool) -> ConnectorBuilder {
179    self.nodelay = value;
180    self
181  }
182  /// Sets value for the `SO_KEEPALIVE` option on this socket.
183  ///
184  /// Default is `false`.
185  pub fn keepalive(mut self, value: bool) -> ConnectorBuilder {
186    self.keepalive = value;
187    self
188  }
189  /// Controls the use of Server Name Indication (SNI).
190  ///
191  /// Defaults to `true`.
192  #[cfg(feature = "tls")]
193  pub fn tls_sni(mut self, value: bool) -> ConnectorBuilder {
194    self.tls_config.tls_sni = value;
195    self
196  }
197  /// Adds a certificate to the set of roots that the connector will trust.
198  #[cfg(feature = "tls")]
199  pub fn certificate(mut self, value: Vec<Certificate>) -> ConnectorBuilder {
200    self.tls_config.certificate = value;
201    self
202  }
203  /// Sets the identity to be used for client certificate authentication.
204  #[cfg(feature = "tls")]
205  pub fn identity(mut self, value: Identity) -> ConnectorBuilder {
206    self.tls_config.identity = Some(value);
207    self
208  }
209  /// Enables a read timeout.
210  ///
211  /// The timeout applies to each read operation, and resets after a
212  /// successful read. This is more appropriate for detecting stalled
213  /// connections when the size isn't known beforehand.
214  ///
215  /// Default is 30 seconds.
216  pub fn read_timeout(mut self, timeout: Option<Duration>) -> ConnectorBuilder {
217    self.read_timeout = timeout;
218    self
219  }
220  /// Enables a write timeout.
221  ///
222  /// The timeout applies to each read operation, and resets after a
223  /// successful read. This is more appropriate for detecting stalled
224  /// connections when the size isn't known beforehand.
225  ///
226  /// Default is 30 seconds.
227  pub fn write_timeout(mut self, timeout: Option<Duration>) -> ConnectorBuilder {
228    self.write_timeout = timeout;
229    self
230  }
231  /// Set a timeout for only the connect phase of a `Client`.
232  ///
233  /// Default is 10 seconds.
234  ///
235  /// # Note
236  ///
237  /// This **requires** the futures be executed in a tokio runtime with
238  /// a tokio timer enabled.
239  pub fn connect_timeout(mut self, timeout: Option<Duration>) -> ConnectorBuilder {
240    self.connect_timeout = timeout;
241    self
242  }
243  // Proxy options
244
245  /// Add a `Proxy` to the list of proxies the `Client` will use.
246  ///
247  /// # Note
248  ///
249  /// Adding a proxy will disable the automatic usage of the "system" proxy.
250  pub fn proxy(mut self, addr: Option<Proxy>) -> ConnectorBuilder {
251    self.proxy = addr;
252    self
253  }
254  /// Set the minimum required TLS version for connections.
255  ///
256  /// By default, the `native_tls::Protocol` default is used.
257  ///
258  /// # Optional
259  ///
260  /// This requires the optional `tls` feature to be enabled.
261  #[cfg(feature = "tls")]
262  pub fn min_tls_version(mut self, version: Option<tls::Version>) -> ConnectorBuilder {
263    self.tls_config.min_tls_version = version;
264    self
265  }
266  /// Set the maximum required TLS version for connections.
267  ///
268  /// By default, the `native_tls::Protocol` default is used.
269  ///
270  /// # Optional
271  ///
272  /// This requires the optional `tls` feature to be enabled.
273  #[cfg(feature = "tls")]
274  pub fn max_tls_version(mut self, version: Option<tls::Version>) -> ConnectorBuilder {
275    self.tls_config.max_tls_version = version;
276    self
277  }
278
279  /// Set a custom TLS connector for custom TLS handshake implementations.
280  ///
281  /// This is available when the `tls` feature is enabled. It allows you to provide
282  /// your own TLS implementation using libraries like openssl, boringssl, native-tls,
283  /// or any other TLS library. When the `rustls` feature is enabled, a default rustls
284  /// implementation is used if no custom connector is provided.
285  ///
286  /// # Example
287  ///
288  /// ```ignore
289  /// use slinger::connector::{ConnectorBuilder, CustomTlsConnector};
290  /// use slinger::{Result, Socket};
291  /// use std::sync::Arc;
292  ///
293  /// struct MyTlsConnector;
294  ///
295  /// impl CustomTlsConnector for MyTlsConnector {
296  ///     fn connect<'a>(
297  ///         &'a self,
298  ///         domain: &'a str,
299  ///         stream: Socket,
300  ///     ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Socket>> + Send + 'a>> {
301  ///         Box::pin(async move {
302  ///             // Your custom TLS handshake logic here
303  ///             todo!()
304  ///         })
305  ///     }
306  /// }
307  ///
308  /// let connector = ConnectorBuilder::default()
309  ///     .custom_tls_connector(Arc::new(MyTlsConnector))
310  ///     .build()?;
311  /// ```
312  #[cfg(feature = "tls")]
313  pub fn custom_tls_connector(
314    mut self,
315    connector: std::sync::Arc<dyn CustomTlsConnector>,
316  ) -> ConnectorBuilder {
317    self.custom_tls_connector = Some(connector);
318    self
319  }
320}
321
322impl ConnectorBuilder {
323  /// Combine the configuration of this builder with a connector to create a `Connector`.
324  pub fn build(&self) -> Result<Connector> {
325    #[cfg(feature = "tls")]
326    let tls = {
327      // custom connector takes precedence; otherwise if rustls is enabled, build it from config
328      if let Some(custom) = &self.custom_tls_connector {
329        custom.clone()
330      } else {
331        #[cfg(feature = "rustls")]
332        {
333          // Try to convert the builder into a rustls connector. Clone self because TryInto consumes it.
334          self.tls_config.custom(self.connect_timeout)?
335        }
336        #[cfg(not(feature = "rustls"))]
337        {
338          return Err(crate::errors::builder(
339            "TLS feature enabled without backend: please enable 'rustls' feature, or provide a custom TLS connector using .custom_tls_connector()",
340          ));
341        }
342      }
343    };
344    let conn = Connector {
345      connect_timeout: self.connect_timeout,
346      nodelay: self.nodelay,
347      keepalive: self.keepalive,
348      read_timeout: self.read_timeout,
349      write_timeout: self.write_timeout,
350      proxy: self.proxy.clone(),
351      #[cfg(feature = "tls")]
352      tls,
353    };
354    Ok(conn)
355  }
356}
357
358/// Connector
359// #[derive(Debug)]
360pub struct Connector {
361  connect_timeout: Option<Duration>,
362  nodelay: bool,
363  keepalive: bool,
364  read_timeout: Option<Duration>,
365  write_timeout: Option<Duration>,
366  proxy: Option<Proxy>,
367  #[cfg(feature = "tls")]
368  tls: std::sync::Arc<dyn CustomTlsConnector>,
369}
370
371impl PartialEq for Connector {
372  fn eq(&self, _other: &Self) -> bool {
373    true
374  }
375}
376
377impl Connector {
378  /// Connect to a remote endpoint with addr
379  pub async fn connect_with_addr<S: Into<SocketAddr>>(&self, addr: S) -> Result<Socket> {
380    let addr = addr.into();
381    let raw_socket = RawSocket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
382    raw_socket.set_nonblocking(true)?;
383    // 阻塞才能设置超时,异步在这设置没意义
384    // raw_socket.set_write_timeout(self.write_timeout)?;
385    // raw_socket.set_read_timeout(self.read_timeout)?;
386    let socket = TcpSocket::from_std_stream(raw_socket.into());
387    if self.nodelay {
388      socket.set_nodelay(self.nodelay)?;
389    }
390    if self.keepalive {
391      socket.set_keepalive(self.keepalive)?;
392    }
393    let s = match self.connect_timeout {
394      None => socket.connect(addr).await?,
395      Some(timeout) => tokio::time::timeout(timeout, socket.connect(addr))
396        .await
397        .map_err(|x| crate::errors::new_io_error(std::io::ErrorKind::TimedOut, &x.to_string()))??,
398    };
399    Ok(Socket::new(
400      StreamWrapper::Tcp(s),
401      self.read_timeout,
402      self.write_timeout,
403    ))
404  }
405  /// Connect to a remote endpoint with url
406  pub async fn connect_with_uri(&self, target: &http::Uri) -> Result<Socket> {
407    ProxySocket::new(target, &self.proxy)
408      .conn_with_connector(self)
409      .await
410  }
411  #[cfg(feature = "tls")]
412  /// A `Connector` will use transport layer security (TLS) by default to connect to destinations.
413  pub async fn upgrade_to_tls(&self, stream: Socket, domain: &str) -> Result<Socket> {
414    self.tls.connect(domain, stream).await
415  }
416}
417
418//
419impl Default for Connector {
420  fn default() -> Self {
421    ConnectorBuilder::default()
422      .build()
423      .expect("new default connector failure")
424  }
425}