1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4 clippy::must_use_candidate,
5 clippy::unwrap_in_result,
6 clippy::panic_in_result_fn
7)]
8#![allow(clippy::large_enum_variant, clippy::result_large_err)]
9
10use cfg_if::cfg_if;
62use std::{
63 convert::TryFrom,
64 error::Error,
65 fmt,
66 io::{self, IoSlice, IoSliceMut, Read, Write},
67 net::{TcpStream as StdTcpStream, ToSocketAddrs},
68 ops::{Deref, DerefMut},
69 time::Duration,
70};
71
72#[cfg(feature = "rustls")]
73mod rustls_impl;
74#[cfg(feature = "rustls")]
75pub use rustls_impl::*;
76
77#[cfg(feature = "native-tls")]
78mod native_tls_impl;
79#[cfg(feature = "native-tls")]
80pub use native_tls_impl::*;
81
82#[cfg(feature = "openssl")]
83mod openssl_impl;
84#[cfg(feature = "openssl")]
85pub use openssl_impl::*;
86
87#[cfg(feature = "futures")]
88mod futures;
89#[cfg(feature = "futures")]
90pub use futures::*;
91
/// A TCP stream that is either a plain socket or a TLS stream built on one.
///
/// Which TLS variants exist depends on the enabled cargo features
/// (`native-tls`, `openssl`, `rustls`); without any of them only
/// [`TcpStream::Plain`] is available.
#[non_exhaustive]
pub enum TcpStream {
    /// An unencrypted std TCP stream.
    Plain(StdTcpStream),
    /// A TLS stream handled by the `native-tls` backend.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsStream),
    /// A TLS stream handled by the `openssl` backend.
    #[cfg(feature = "openssl")]
    Openssl(OpensslStream),
    /// A TLS stream handled by the `rustls` backend.
    #[cfg(feature = "rustls")]
    Rustls(RustlsStream),
}
107
/// Borrowed TLS settings passed to [`TcpStream::into_tls`].
///
/// `Default` yields a configuration with no identity and no certificate chain.
#[derive(Default, Debug, PartialEq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Optional identity material handed to the TLS backend (see [`Identity`]).
    pub identity: Option<Identity<'data, 'key>>,
    /// Optional certificate chain as text (presumably PEM — parsing is done by the
    /// backend modules; confirm there).
    pub cert_chain: Option<&'chain str>,
}
116
/// Owned counterpart of [`TLSConfig`]; borrow it with [`OwnedTLSConfig::as_ref`].
#[derive(Clone, Default, Debug, PartialEq)]
pub struct OwnedTLSConfig {
    /// Optional owned identity material (see [`OwnedIdentity`]).
    pub identity: Option<OwnedIdentity>,
    /// Optional certificate chain as owned text.
    pub cert_chain: Option<String>,
}
125
126impl OwnedTLSConfig {
127 #[must_use]
129 pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
130 TLSConfig {
131 identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
132 cert_chain: self.cert_chain.as_deref(),
133 }
134 }
135}
136
/// Identity material handed to the TLS backend, borrowed from caller-owned buffers.
#[derive(Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// A PKCS#12 identity.
    PKCS12 {
        /// DER-encoded PKCS#12 data.
        der: &'data [u8],
        /// Password for the PKCS#12 data.
        password: &'key str,
    },
    /// A PEM identity paired with separate key bytes.
    PKCS8 {
        /// PEM-encoded data (presumably the certificate(s) — confirm against the backends).
        pem: &'data [u8],
        /// Key bytes paired with `pem` (presumably a PKCS#8 private key).
        key: &'key [u8],
    },
}
157
/// Owned counterpart of [`Identity`]; borrow it with [`OwnedIdentity::as_ref`].
#[derive(Clone, Debug, PartialEq)]
pub enum OwnedIdentity {
    /// A PKCS#12 identity.
    PKCS12 {
        /// DER-encoded PKCS#12 data.
        der: Vec<u8>,
        /// Password for the PKCS#12 data.
        password: String,
    },
    /// A PEM identity paired with separate key bytes.
    PKCS8 {
        /// PEM-encoded data (presumably the certificate(s) — confirm against the backends).
        pem: Vec<u8>,
        /// Key bytes paired with `pem` (presumably a PKCS#8 private key).
        key: Vec<u8>,
    },
}
178
179impl OwnedIdentity {
180 #[must_use]
182 pub fn as_ref(&self) -> Identity<'_, '_> {
183 match self {
184 Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
185 Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
186 }
187 }
188}
189
/// Outcome of a TLS upgrade or handshake step: the resulting [`TcpStream`] or a [`HandshakeError`].
pub type HandshakeResult = Result<TcpStream, HandshakeError>;
192
193impl TcpStream {
194 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
196 connect_std(addr, None).and_then(Self::try_from)
197 }
198
199 pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
201 connect_std(addr, Some(timeout)).and_then(Self::try_from)
202 }
203
204 pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
206 Self::try_from(stream)
207 }
208
209 pub fn is_readable(&self) -> io::Result<()> {
211 self.deref().read(&mut []).map(|_| ())
212 }
213
214 pub fn is_writable(&self) -> io::Result<()> {
216 is_writable(self.deref())
217 }
218
219 pub fn try_connect(&mut self) -> io::Result<bool> {
224 try_connect(self)
225 }
226
227 pub fn into_tls(
229 self,
230 domain: &str,
231 config: TLSConfig<'_, '_, '_>,
232 ) -> Result<Self, HandshakeError> {
233 into_tls_impl(self, domain, config)
234 }
235
236 #[cfg(feature = "native-tls")]
237 pub fn into_native_tls(
239 self,
240 connector: &NativeTlsConnector,
241 domain: &str,
242 ) -> Result<Self, HandshakeError> {
243 Ok(connector.connect(domain, self.into_plain()?)?.into())
244 }
245
246 #[cfg(feature = "openssl")]
247 pub fn into_openssl(
249 self,
250 connector: &OpensslConnector,
251 domain: &str,
252 ) -> Result<Self, HandshakeError> {
253 Ok(connector.connect(domain, self.into_plain()?)?.into())
254 }
255
256 #[cfg(feature = "rustls")]
257 pub fn into_rustls(
259 self,
260 connector: &RustlsConnector,
261 domain: &str,
262 ) -> Result<Self, HandshakeError> {
263 Ok(connector.connect(domain, self.into_plain()?)?.into())
264 }
265
266 #[allow(irrefutable_let_patterns)]
267 fn into_plain(self) -> Result<StdTcpStream, io::Error> {
268 if let Self::Plain(plain) = self {
269 Ok(plain)
270 } else {
271 Err(io::Error::new(
272 io::ErrorKind::AlreadyExists,
273 "already a TLS stream",
274 ))
275 }
276 }
277}
278
/// Opens a std TCP connection to `addr`.
///
/// Without a `timeout`, this defers entirely to [`StdTcpStream::connect`]. With a
/// `timeout`, each address `addr` resolves to is tried in order via
/// [`StdTcpStream::connect_timeout`]. The first successful connection is returned;
/// otherwise the result is the error from the last attempt, or an `AddrNotAvailable`
/// error when resolution produced no address at all.
fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
    let timeout = match timeout {
        Some(timeout) => timeout,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_error = None;
    for candidate in addr.to_socket_addrs()? {
        match StdTcpStream::connect_timeout(&candidate, timeout) {
            Ok(stream) => return Ok(stream),
            Err(error) => last_error = Some(error),
        }
    }

    Err(last_error.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
    }))
}
296
297fn try_connect(stream: &mut StdTcpStream) -> io::Result<bool> {
298 match is_writable(stream) {
299 Ok(()) => Ok(true),
300 Err(err)
301 if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected].contains(&err.kind()) =>
302 {
303 Ok(false)
304 }
305 Err(err) => Err(err),
306 }
307}
308
/// Issues a zero-length write on `stream`, reporting only whether it errored.
fn is_writable(stream: &StdTcpStream) -> io::Result<()> {
    // `Write` is implemented for `&StdTcpStream`, so a mutable binding of the shared
    // reference is enough to issue the write; the byte count is discarded.
    let mut writer = stream;
    writer.write(&[]).map(|_| ())
}
312
/// Upgrades `s` to TLS for `domain` with the backend chosen at compile time.
///
/// Precedence follows the `cfg_if!` order below:
/// 1. rustls. The connector configuration is built by the first enabled feature among
///    `rustls-platform-verifier`, `rustls-native-certs` and `rustls-webpki-roots-certs`,
///    falling back to `RustlsConnectorConfig::default()`.
/// 2. OpenSSL.
/// 3. native-tls.
///
/// With no TLS feature enabled, `domain` and `config` are ignored and `s` is returned
/// as [`TcpStream::Plain`] via `into_plain`.
fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
    cfg_if! {
        if #[cfg(feature = "rustls-platform-verifier")] {
            into_rustls_impl(s, RustlsConnectorConfig::new_with_platform_verifier(), domain, config)
        } else if #[cfg(feature = "rustls-native-certs")] {
            // This constructor is fallible; its error is propagated with `?`.
            into_rustls_impl(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
        } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
            into_rustls_impl(s, RustlsConnectorConfig::new_with_webpki_root_certs(), domain, config)
        } else if #[cfg(feature = "rustls")] {
            into_rustls_impl(s, RustlsConnectorConfig::default(), domain, config)
        } else if #[cfg(feature = "openssl")] {
            into_openssl_impl(s, domain, config)
        } else if #[cfg(feature = "native-tls")] {
            into_native_tls_impl(s, domain, config)
        } else {
            // Consume the otherwise-unused arguments so the no-TLS build stays warning-free.
            let _ = (domain, config);
            Ok(TcpStream::Plain(s.into_plain()?))
        }
    }
}
333
334impl TryFrom<StdTcpStream> for TcpStream {
335 type Error = io::Error;
336
337 fn try_from(s: StdTcpStream) -> io::Result<Self> {
338 s.set_nodelay(true)?;
339 let mut this = Self::Plain(s);
340 this.try_connect()?;
341 Ok(this)
342 }
343}
344
345impl Deref for TcpStream {
346 type Target = StdTcpStream;
347
348 fn deref(&self) -> &Self::Target {
349 match self {
350 Self::Plain(plain) => plain,
351 #[cfg(feature = "native-tls")]
352 Self::NativeTls(tls) => tls.get_ref(),
353 #[cfg(feature = "openssl")]
354 Self::Openssl(tls) => tls.get_ref(),
355 #[cfg(feature = "rustls")]
356 Self::Rustls(tls) => tls.get_ref(),
357 }
358 }
359}
360
361impl DerefMut for TcpStream {
362 fn deref_mut(&mut self) -> &mut Self::Target {
363 match self {
364 Self::Plain(plain) => plain,
365 #[cfg(feature = "native-tls")]
366 Self::NativeTls(tls) => tls.get_mut(),
367 #[cfg(feature = "openssl")]
368 Self::Openssl(tls) => tls.get_mut(),
369 #[cfg(feature = "rustls")]
370 Self::Rustls(tls) => tls.get_mut(),
371 }
372 }
373}
374
// Forwards `$self.$method(args...)` to the inner stream of whichever `TcpStream`
// variant is active. The TLS arms are only emitted when their backend feature is
// enabled, mirroring the cfg-gated variants of `TcpStream`.
//
// The matcher requires a comma after `$method` even when there are no arguments,
// so zero-argument methods are forwarded as `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($self:ident, $method:ident, $($args:expr),*) => {
        match $self {
            Self::Plain(plain) => plain.$method($($args),*),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(tls) => tls.$method($($args),*),
            #[cfg(feature = "openssl")]
            Self::Openssl(tls) => tls.$method($($args),*),
            #[cfg(feature = "rustls")]
            Self::Rustls(tls) => tls.$method($($args),*),
        }
    };
}
388
389impl Read for TcpStream {
390 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
391 fwd_impl!(self, read, buf)
392 }
393
394 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
395 fwd_impl!(self, read_vectored, bufs)
396 }
397
398 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
399 fwd_impl!(self, read_to_end, buf)
400 }
401
402 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
403 fwd_impl!(self, read_to_string, buf)
404 }
405
406 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
407 fwd_impl!(self, read_exact, buf)
408 }
409}
410
411impl Write for TcpStream {
412 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
413 fwd_impl!(self, write, buf)
414 }
415
416 fn flush(&mut self) -> io::Result<()> {
417 fwd_impl!(self, flush,)
418 }
419
420 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
421 fwd_impl!(self, write_vectored, bufs)
422 }
423
424 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
425 fwd_impl!(self, write_all, buf)
426 }
427
428 fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
429 fwd_impl!(self, write_fmt, fmt)
430 }
431}
432
433impl fmt::Debug for TcpStream {
434 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435 f.debug_struct("TcpStream")
436 .field("inner", self.deref())
437 .finish()
438 }
439}
440
/// A handshake in progress, carried by [`HandshakeError::WouldBlock`] and resumed
/// with [`MidHandshakeTlsStream::handshake`].
#[derive(Debug)]
pub enum MidHandshakeTlsStream {
    /// A plain stream; `handshake` returns it as-is once the socket probe succeeds.
    Plain(TcpStream),
    /// A pending `native-tls` handshake.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsMidHandshakeTlsStream),
    /// A pending OpenSSL handshake.
    #[cfg(feature = "openssl")]
    Openssl(OpensslMidHandshakeTlsStream),
    /// A pending rustls handshake.
    #[cfg(feature = "rustls")]
    Rustls(RustlsMidHandshakeTlsStream),
}
456
457impl MidHandshakeTlsStream {
458 #[must_use]
460 pub fn get_ref(&self) -> &StdTcpStream {
461 match self {
462 Self::Plain(mid) => mid,
463 #[cfg(feature = "native-tls")]
464 Self::NativeTls(mid) => mid.get_ref(),
465 #[cfg(feature = "openssl")]
466 Self::Openssl(mid) => mid.get_ref(),
467 #[cfg(feature = "rustls")]
468 Self::Rustls(mid) => mid.get_ref(),
469 }
470 }
471
472 #[must_use]
474 pub fn get_mut(&mut self) -> &mut StdTcpStream {
475 match self {
476 Self::Plain(mid) => mid,
477 #[cfg(feature = "native-tls")]
478 Self::NativeTls(mid) => mid.get_mut(),
479 #[cfg(feature = "openssl")]
480 Self::Openssl(mid) => mid.get_mut(),
481 #[cfg(feature = "rustls")]
482 Self::Rustls(mid) => mid.get_mut(),
483 }
484 }
485
486 pub fn handshake(mut self) -> HandshakeResult {
488 if !try_connect(self.get_mut())? {
489 return Err(HandshakeError::WouldBlock(self));
490 }
491
492 Ok(match self {
493 Self::Plain(mid) => mid,
494 #[cfg(feature = "native-tls")]
495 Self::NativeTls(mid) => mid.handshake()?.into(),
496 #[cfg(feature = "openssl")]
497 Self::Openssl(mid) => mid.handshake()?.into(),
498 #[cfg(feature = "rustls")]
499 Self::Rustls(mid) => mid.handshake()?.into(),
500 })
501 }
502}
503
504impl From<TcpStream> for MidHandshakeTlsStream {
505 fn from(mid: TcpStream) -> Self {
506 Self::Plain(mid)
507 }
508}
509
510impl fmt::Display for MidHandshakeTlsStream {
511 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
512 f.write_str("MidHandshakeTlsStream")
513 }
514}
515
/// Error returned when a TLS upgrade or handshake step does not complete.
#[derive(Debug)]
pub enum HandshakeError {
    /// The handshake could not progress yet; the in-progress state is returned so it
    /// can be resumed with [`MidHandshakeTlsStream::handshake`].
    WouldBlock(MidHandshakeTlsStream),
    /// The handshake failed with an I/O error.
    Failure(io::Error),
}
524
525impl HandshakeError {
526 pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
528 match self {
529 Self::WouldBlock(mid) => Ok(mid),
530 Self::Failure(error) => Err(error),
531 }
532 }
533}
534
535impl fmt::Display for HandshakeError {
536 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
537 match self {
538 Self::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
539 Self::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
540 }
541 }
542}
543
544impl Error for HandshakeError {
545 fn source(&self) -> Option<&(dyn Error + 'static)> {
546 match self {
547 Self::Failure(err) => Some(err),
548 _ => None,
549 }
550 }
551}
552
553impl From<io::Error> for HandshakeError {
554 fn from(err: io::Error) -> Self {
555 Self::Failure(err)
556 }
557}
558
559mod sys;