1#![deny(missing_docs)]
2
3use cfg_if::cfg_if;
49use std::{
50 convert::TryFrom,
51 error::Error,
52 fmt,
53 io::{self, IoSlice, IoSliceMut, Read, Write},
54 net::{TcpStream as StdTcpStream, ToSocketAddrs},
55 ops::{Deref, DerefMut},
56 time::Duration,
57};
58
/// Re-export of `native_tls::TlsConnector` under the name `NativeTlsConnector`.
#[cfg(feature = "native-tls")]
pub use native_tls::TlsConnector as NativeTlsConnector;

/// A `native_tls::TlsStream` layered over this crate's [`TcpStream`].
#[cfg(feature = "native-tls")]
pub type NativeTlsStream = native_tls::TlsStream<TcpStream>;

/// A `native_tls::MidHandshakeTlsStream` layered over this crate's [`TcpStream`].
#[cfg(feature = "native-tls")]
pub type NativeTlsMidHandshakeTlsStream = native_tls::MidHandshakeTlsStream<TcpStream>;

/// A `native_tls::HandshakeError` whose underlying stream is this crate's [`TcpStream`].
#[cfg(feature = "native-tls")]
pub type NativeTlsHandshakeError = native_tls::HandshakeError<TcpStream>;
74
/// Re-export of `openssl::ssl::SslConnector` and `openssl::ssl::SslMethod` under
/// the names `OpenSslConnector` and `OpenSslMethod`.
#[cfg(feature = "openssl")]
pub use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod as OpenSslMethod};

/// An `openssl::ssl::SslStream` layered over this crate's [`TcpStream`].
#[cfg(feature = "openssl")]
pub type OpenSslStream = openssl::ssl::SslStream<TcpStream>;

/// An `openssl::ssl::MidHandshakeSslStream` layered over this crate's [`TcpStream`].
#[cfg(feature = "openssl")]
pub type OpenSslMidHandshakeTlsStream = openssl::ssl::MidHandshakeSslStream<TcpStream>;

/// An `openssl::ssl::HandshakeError` whose underlying stream is this crate's [`TcpStream`].
#[cfg(feature = "openssl")]
pub type OpenSslHandshakeError = openssl::ssl::HandshakeError<TcpStream>;

/// Alias for `openssl::error::ErrorStack`.
#[cfg(feature = "openssl")]
pub type OpenSslErrorStack = openssl::error::ErrorStack;
94
/// Re-export of `rustls_connector::RustlsConnector` and
/// `rustls_connector::RustlsConnectorConfig`.
#[cfg(feature = "rustls-common")]
pub use rustls_connector::{RustlsConnector, RustlsConnectorConfig};

/// A `rustls_connector::TlsStream` layered over this crate's [`TcpStream`].
#[cfg(feature = "rustls-common")]
pub type RustlsStream = rustls_connector::TlsStream<TcpStream>;

/// A `rustls_connector::MidHandshakeTlsStream` layered over this crate's [`TcpStream`].
#[cfg(feature = "rustls-common")]
pub type RustlsMidHandshakeTlsStream = rustls_connector::MidHandshakeTlsStream<TcpStream>;

/// A `rustls_connector::HandshakeError` whose underlying stream is this crate's [`TcpStream`].
#[cfg(feature = "rustls-common")]
pub type RustlsHandshakeError = rustls_connector::HandshakeError<TcpStream>;
110
/// A TCP stream that is either plain or wrapped by one of the enabled TLS backends.
pub enum TcpStream {
    /// An unencrypted stream. The `bool` is the "connected" flag read by
    /// `is_connected` and set by `try_connect`.
    Plain(StdTcpStream, bool),
    /// A stream encrypted through the `native-tls` backend.
    #[cfg(feature = "native-tls")]
    NativeTls(Box<NativeTlsStream>),
    /// A stream encrypted through the `openssl` backend.
    #[cfg(feature = "openssl")]
    OpenSsl(Box<OpenSslStream>),
    /// A stream encrypted through the `rustls` backend.
    #[cfg(feature = "rustls-common")]
    Rustls(Box<RustlsStream>),
}
125
/// Borrowed TLS settings handed to `TcpStream::into_tls`.
#[derive(Default, Debug, PartialEq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Optional client identity (certificate and private key) presented to the peer.
    pub identity: Option<Identity<'data, 'key>>,
    /// Optional PEM text holding extra certificates; the TLS backends parse it
    /// and add the certificates to their trusted set.
    pub cert_chain: Option<&'chain str>,
}
134
/// Owning counterpart of [`TLSConfig`]; convert with `OwnedTLSConfig::as_ref`.
#[derive(Default, Debug, PartialEq)]
pub struct OwnedTLSConfig {
    /// Optional client identity (certificate and private key).
    pub identity: Option<OwnedIdentity>,
    /// Optional PEM text holding extra certificates.
    pub cert_chain: Option<String>,
}
143
144impl OwnedTLSConfig {
145 #[must_use]
147 pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
148 TLSConfig {
149 identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
150 cert_chain: self.cert_chain.as_deref(),
151 }
152 }
153}
154
/// Borrowed client identity used for TLS client authentication.
#[derive(Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// Identity stored in a PKCS#12 archive.
    PKCS12 {
        /// DER-encoded PKCS#12 archive.
        der: &'data [u8],
        /// Password passed to the PKCS#12 parser.
        password: &'key str,
    },
    /// Identity given as a PEM certificate chain plus a PEM private key.
    PKCS8 {
        /// PEM-encoded certificate(s).
        pem: &'data [u8],
        /// PEM-encoded private key.
        key: &'key [u8],
    },
}
175
/// Owning counterpart of [`Identity`]; convert with `OwnedIdentity::as_ref`.
#[derive(Debug, PartialEq)]
pub enum OwnedIdentity {
    /// Identity stored in a PKCS#12 archive.
    PKCS12 {
        /// DER-encoded PKCS#12 archive.
        der: Vec<u8>,
        /// Password passed to the PKCS#12 parser.
        password: String,
    },
    /// Identity given as a PEM certificate chain plus a PEM private key.
    PKCS8 {
        /// PEM-encoded certificate(s).
        pem: Vec<u8>,
        /// PEM-encoded private key.
        key: Vec<u8>,
    },
}
196
197impl OwnedIdentity {
198 #[must_use]
200 pub fn as_ref(&self) -> Identity<'_, '_> {
201 match self {
202 Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
203 Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
204 }
205 }
206}
207
/// Result of a TLS handshake attempt: the resulting [`TcpStream`] or a [`HandshakeError`].
pub type HandshakeResult = Result<TcpStream, HandshakeError>;
210
211impl TcpStream {
212 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
214 connect_std(addr, None).and_then(Self::try_from)
215 }
216
217 pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
219 connect_std(addr, Some(timeout)).and_then(Self::try_from)
220 }
221
222 pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
224 Self::try_from(stream)
225 }
226
227 #[must_use]
229 #[allow(irrefutable_let_patterns)]
230 pub fn is_connected(&self) -> bool {
231 if let Self::Plain(_, connected) = self {
232 *connected
233 } else {
234 true
235 }
236 }
237
238 #[allow(irrefutable_let_patterns)]
243 pub fn try_connect(&mut self) -> io::Result<bool> {
244 if self.is_connected() {
245 return Ok(true);
246 }
247 match self.is_writable() {
248 Ok(()) => {
249 if let Self::Plain(_, connected) = self {
250 *connected = true;
251 }
252 Ok(true)
253 }
254 Err(err)
255 if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected]
256 .contains(&err.kind()) =>
257 {
258 Ok(false)
259 }
260 Err(err) => Err(err),
261 }
262 }
263
264 pub fn into_tls(
266 self,
267 domain: &str,
268 config: TLSConfig<'_, '_, '_>,
269 ) -> Result<Self, HandshakeError> {
270 into_tls_impl(self, domain, config)
271 }
272
273 #[cfg(feature = "native-tls")]
274 pub fn into_native_tls(
276 self,
277 connector: &NativeTlsConnector,
278 domain: &str,
279 ) -> Result<Self, HandshakeError> {
280 Ok(connector.connect(domain, self.into_plain()?)?.into())
281 }
282
283 #[cfg(feature = "openssl")]
284 pub fn into_openssl(
286 self,
287 connector: &OpenSslConnector,
288 domain: &str,
289 ) -> Result<Self, HandshakeError> {
290 Ok(connector.connect(domain, self.into_plain()?)?.into())
291 }
292
293 #[cfg(feature = "rustls-common")]
294 pub fn into_rustls(
296 self,
297 connector: &RustlsConnector,
298 domain: &str,
299 ) -> Result<Self, HandshakeError> {
300 Ok(connector.connect(domain, self.into_plain()?)?.into())
301 }
302
303 #[allow(irrefutable_let_patterns)]
304 fn into_plain(self) -> Result<TcpStream, io::Error> {
305 if let TcpStream::Plain(plain, connected) = self {
306 Ok(TcpStream::Plain(plain, connected))
307 } else {
308 Err(io::Error::new(
309 io::ErrorKind::AlreadyExists,
310 "already a TLS stream",
311 ))
312 }
313 }
314}
315
316fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
317 let stream = connect_std_raw(addr, timeout)?;
318 stream.set_nodelay(true)?;
319 Ok(stream)
320}
321
/// Resolves `addr` and returns a connection to the first address that accepts one.
///
/// Addresses are tried in resolution order. When `timeout` is `Some`, every
/// attempt uses `connect_timeout` with that duration, so no single address can
/// block for longer than the timeout; when it is `None`, a plain blocking
/// `connect` is used.
///
/// # Errors
///
/// Returns the resolution error if `addr` cannot be resolved, the error of the
/// last failed attempt if every address fails, or an `AddrNotAvailable` error
/// if resolution produced no address at all.
fn connect_std_raw<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> io::Result<StdTcpStream> {
    let mut last_err = None;
    for addr in addr.to_socket_addrs()? {
        let attempt = match timeout {
            Some(timeout) => StdTcpStream::connect_timeout(&addr, timeout),
            None => StdTcpStream::connect(addr),
        };
        match attempt {
            Ok(stream) => return Ok(stream),
            // Remember the failure and move on to the next candidate.
            Err(error) => last_err = Some(error),
        }
    }
    Err(last_err.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
    }))
}
346
/// Builds a rustls connector from `c` and `config`, then upgrades `s` to TLS for `domain`.
///
/// Certificates found in `config.cert_chain` are added to the connector
/// configuration. When `config.identity` is set it is used as the client
/// certificate; otherwise the connector is built without client authentication.
///
/// # Errors
///
/// Returns a failure when PEM or PKCS#12 data cannot be parsed, when the
/// connector cannot be built, or when the handshake fails or would block.
#[cfg(feature = "rustls-common")]
fn into_rustls_common(
    s: TcpStream,
    mut c: RustlsConnectorConfig,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> HandshakeResult {
    use rustls_connector::rustls_pki_types::{
        pem::PemObject, CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer,
    };

    if let Some(extra_pem) = config.cert_chain {
        let mut reader = std::io::BufReader::new(extra_pem.as_bytes());
        let parsed = rustls_pemfile::certs(&mut reader)
            .collect::<Result<Vec<_>, _>>()
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
        c.add_parsable_certificates(parsed);
    }

    // No client identity: connect without client authentication.
    let Some(identity) = config.identity else {
        let connector = c.connector_with_no_client_auth();
        return s.into_rustls(&connector, domain);
    };

    let (certs, key) = match identity {
        Identity::PKCS8 { pem, key } => {
            let mut reader = std::io::BufReader::new(pem);
            let certs = rustls_pemfile::certs(&mut reader)
                .collect::<Result<Vec<_>, _>>()
                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
            let key = PrivateKeyDer::from_pem_slice(key)
                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
            (certs, key)
        }
        Identity::PKCS12 { der, password } => {
            let store = p12_keystore::KeyStore::from_pkcs12(der, password)
                .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
            let Some((_, keychain)) = store.private_key_chain() else {
                let missing = io::Error::new(io::ErrorKind::Other, "No private key in pkcs12 DER");
                return Err(missing.into());
            };
            let certs = keychain
                .chain()
                .iter()
                .map(|cert| CertificateDer::from(cert.as_der().to_vec()))
                .collect();
            let key = PrivateKeyDer::from(PrivatePkcs8KeyDer::from(keychain.key().to_vec()));
            (certs, key)
        }
    };

    let connector = c
        .connector_with_single_cert(certs, key)
        .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
    s.into_rustls(&connector, domain)
}
406
// Selects the single `into_tls_impl` used by `TcpStream::into_tls`.
//
// The first enabled feature in this order wins: rustls with native certs,
// rustls with webpki roots, bare rustls, openssl, native-tls. With no TLS
// feature enabled the fallback at the bottom is compiled instead.
cfg_if! {
    if #[cfg(feature = "rustls-native-certs")] {
        // rustls, with the connector config built by `new_with_native_certs`.
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_rustls_common(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
        }
    } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
        // rustls, with the connector config built by `new_with_webpki_roots_certs`.
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_rustls_common(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
        }
    } else if #[cfg(feature = "rustls-common")] {
        // rustls, with the default connector config.
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_rustls_common(s, RustlsConnectorConfig::default(), domain, config)
        }
    } else if #[cfg(feature = "openssl")] {
        // openssl: build an `SslConnector` from `config`, then hand over to `into_openssl`.
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            use openssl::x509::X509;

            let mut builder = OpenSslConnector::builder(OpenSslMethod::tls())?;
            if let Some(identity) = config.identity {
                // Normalise both identity formats to (certificate, private key, extra chain).
                let (cert, pkey, chain) = match identity {
                    Identity::PKCS8 { pem, key } => {
                        let pkey = openssl::pkey::PKey::private_key_from_pem(key)?;
                        // The first certificate in the PEM data is used as the
                        // client certificate; the remainder becomes the extra chain.
                        let mut chain = openssl::x509::X509::stack_from_pem(pem)?.into_iter();
                        let cert = chain.next();
                        (cert, Some(pkey), Some(chain.collect()))
                    }
                    Identity::PKCS12 { der, password } => {
                        let mut openssl_identity = openssl::pkcs12::Pkcs12::from_der(der)?.parse2(password)?;
                        (openssl_identity.cert, openssl_identity.pkey, openssl_identity.ca.take().map(|stack| stack.into_iter().collect::<Vec<_>>()))
                    },
                };
                if let Some(cert) = cert.as_ref() {
                    builder.set_certificate(cert)?;
                }
                if let Some(pkey) = pkey.as_ref() {
                    builder.set_private_key(pkey)?;
                }
                if let Some(chain) = chain.as_ref() {
                    // Extra chain certificates are added in reverse order.
                    for cert in chain.iter().rev() {
                        builder.add_extra_chain_cert(cert.to_owned())?;
                    }
                }
            }
            if let Some(cert_chain) = config.cert_chain.as_ref() {
                // Certificates from `config.cert_chain` go into the connector's
                // certificate store, also in reverse order.
                for cert in X509::stack_from_pem(cert_chain.as_bytes())?.drain(..).rev() {
                    builder.cert_store_mut().add_cert(cert)?;
                }
            }
            s.into_openssl(&builder.build(), domain)
        }
    } else if #[cfg(feature = "native-tls")] {
        // native-tls: build a `TlsConnector` from `config`, then hand over to `into_native_tls`.
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            use native_tls::Certificate;

            let mut builder = NativeTlsConnector::builder();
            if let Some(identity) = config.identity {
                let native_identity = match identity {
                    Identity::PKCS8 { pem, key } => native_tls::Identity::from_pkcs8(pem, key),
                    Identity::PKCS12 { der, password } => native_tls::Identity::from_pkcs12(der, password),
                };
                builder.identity(native_identity.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?);
            }
            if let Some(cert_chain) = config.cert_chain {
                // Each PEM certificate from `config.cert_chain` is registered as a root certificate.
                let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
                for cert in rustls_pemfile::certs(&mut cert_chain).collect::<Result<Vec<_>, _>>()? {
                    builder.add_root_certificate(Certificate::from_der(&cert[..]).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?);
                }
            }
            s.into_native_tls(&builder.build().map_err(|e| io::Error::new(io::ErrorKind::Other, e))?, domain)
        }
    } else {
        // NOTE(review): with no TLS feature enabled this returns the plain
        // stream unchanged, so no encryption is performed; `domain` and the
        // config are ignored.
        fn into_tls_impl(s: TcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            Ok(s.into_plain()?)
        }
    }
}
483
484impl TryFrom<StdTcpStream> for TcpStream {
485 type Error = io::Error;
486
487 fn try_from(s: StdTcpStream) -> io::Result<Self> {
488 let mut this = TcpStream::Plain(s, false);
489 this.try_connect()?;
490 Ok(this)
491 }
492}
493
#[cfg(feature = "native-tls")]
impl From<NativeTlsStream> for TcpStream {
    /// Boxes the native-tls stream into the `NativeTls` variant.
    fn from(s: NativeTlsStream) -> Self {
        let boxed = Box::new(s);
        Self::NativeTls(boxed)
    }
}
500
#[cfg(feature = "openssl")]
impl From<OpenSslStream> for TcpStream {
    /// Boxes the openssl stream into the `OpenSsl` variant.
    fn from(s: OpenSslStream) -> Self {
        let boxed = Box::new(s);
        Self::OpenSsl(boxed)
    }
}
507
#[cfg(feature = "rustls-common")]
impl From<RustlsStream> for TcpStream {
    /// Boxes the rustls stream into the `Rustls` variant.
    fn from(s: RustlsStream) -> Self {
        let boxed = Box::new(s);
        Self::Rustls(boxed)
    }
}
514
515impl TcpStream {
516 pub fn is_readable(&self) -> io::Result<()> {
518 self.deref().read(&mut []).map(|_| ())
519 }
520
521 pub fn is_writable(&self) -> io::Result<()> {
523 self.deref().write(&[]).map(|_| ())
524 }
525}
526
/// Gives shared access to the underlying `std::net::TcpStream`, whichever
/// variant is active.
impl Deref for TcpStream {
    type Target = StdTcpStream;

    fn deref(&self) -> &Self::Target {
        match self {
            TcpStream::Plain(plain, _) => plain,
            // The TLS wrappers are parameterised over this crate's `TcpStream`
            // (see the type aliases), so `get_ref` is expected to yield that
            // wrapper, which is then dereferenced again to reach the std socket.
            #[cfg(feature = "native-tls")]
            TcpStream::NativeTls(tls) => tls.get_ref(),
            #[cfg(feature = "openssl")]
            TcpStream::OpenSsl(tls) => tls.get_ref(),
            #[cfg(feature = "rustls-common")]
            TcpStream::Rustls(tls) => tls.get_ref(),
        }
    }
}
542
/// Gives exclusive access to the underlying `std::net::TcpStream`, whichever
/// variant is active.
impl DerefMut for TcpStream {
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            TcpStream::Plain(plain, _) => plain,
            // Same shape as `Deref`: each TLS arm goes through the backend's `get_mut`.
            #[cfg(feature = "native-tls")]
            TcpStream::NativeTls(tls) => tls.get_mut(),
            #[cfg(feature = "openssl")]
            TcpStream::OpenSsl(tls) => tls.get_mut(),
            #[cfg(feature = "rustls-common")]
            TcpStream::Rustls(tls) => tls.get_mut(),
        }
    }
}
556
// Forwards a method call to whichever stream the active `TcpStream` variant holds.
//
// `$self` is the stream expression (an identifier), `$method` the method to
// call and `$args` its comma-separated arguments. Pass a trailing comma and no
// arguments for a zero-argument method, e.g. `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($self:ident, $method:ident, $($args:expr),*) => {
        match $self {
            TcpStream::Plain(plain, _) => plain.$method($($args),*),
            #[cfg(feature = "native-tls")]
            TcpStream::NativeTls(tls) => tls.$method($($args),*),
            #[cfg(feature = "openssl")]
            TcpStream::OpenSsl(tls) => tls.$method($($args),*),
            #[cfg(feature = "rustls-common")]
            TcpStream::Rustls(tls) => tls.$method($($args),*),
        }
    };
}
570
/// Reads go to the plain socket or through the active TLS session.
///
/// Every method listed here is forwarded explicitly with `fwd_impl!`, so the
/// inner stream's own implementation is used instead of the trait's default.
impl Read for TcpStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        fwd_impl!(self, read, buf)
    }

    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
        fwd_impl!(self, read_vectored, bufs)
    }

    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        fwd_impl!(self, read_to_end, buf)
    }

    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
        fwd_impl!(self, read_to_string, buf)
    }

    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        fwd_impl!(self, read_exact, buf)
    }
}
592
/// Writes go to the plain socket or through the active TLS session.
///
/// Every method listed here is forwarded explicitly with `fwd_impl!`, so the
/// inner stream's own implementation is used instead of the trait's default.
impl Write for TcpStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        fwd_impl!(self, write, buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        // Trailing comma with no arguments: the macro's argument list is empty.
        fwd_impl!(self, flush,)
    }

    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        fwd_impl!(self, write_vectored, bufs)
    }

    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        fwd_impl!(self, write_all, buf)
    }

    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
        fwd_impl!(self, write_fmt, fmt)
    }
}
614
615impl fmt::Debug for TcpStream {
616 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
617 f.debug_struct("TcpStream")
618 .field("inner", self.deref())
619 .finish()
620 }
621}
622
/// A stream whose connection or TLS handshake has not completed yet.
///
/// Call `handshake` to try to complete it.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum MidHandshakeTlsStream {
    /// A plain stream with no TLS session attached.
    Plain(TcpStream),
    /// A pending `native-tls` handshake.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsMidHandshakeTlsStream),
    /// A pending `openssl` handshake.
    #[cfg(feature = "openssl")]
    Openssl(OpenSslMidHandshakeTlsStream),
    /// A pending `rustls` handshake.
    #[cfg(feature = "rustls-common")]
    Rustls(RustlsMidHandshakeTlsStream),
}
639
640impl MidHandshakeTlsStream {
641 #[must_use]
643 pub fn get_ref(&self) -> &TcpStream {
644 match self {
645 MidHandshakeTlsStream::Plain(mid) => mid,
646 #[cfg(feature = "native-tls")]
647 MidHandshakeTlsStream::NativeTls(mid) => mid.get_ref(),
648 #[cfg(feature = "openssl")]
649 MidHandshakeTlsStream::Openssl(mid) => mid.get_ref(),
650 #[cfg(feature = "rustls-common")]
651 MidHandshakeTlsStream::Rustls(mid) => mid.get_ref(),
652 }
653 }
654
655 #[must_use]
657 pub fn get_mut(&mut self) -> &mut TcpStream {
658 match self {
659 MidHandshakeTlsStream::Plain(mid) => mid,
660 #[cfg(feature = "native-tls")]
661 MidHandshakeTlsStream::NativeTls(mid) => mid.get_mut(),
662 #[cfg(feature = "openssl")]
663 MidHandshakeTlsStream::Openssl(mid) => mid.get_mut(),
664 #[cfg(feature = "rustls-common")]
665 MidHandshakeTlsStream::Rustls(mid) => mid.get_mut(),
666 }
667 }
668
669 pub fn handshake(self) -> HandshakeResult {
671 Ok(match self {
672 MidHandshakeTlsStream::Plain(mut mid) => {
673 if !mid.try_connect()? {
674 return Err(HandshakeError::WouldBlock(mid.into()));
675 }
676 mid
677 }
678 #[cfg(feature = "native-tls")]
679 MidHandshakeTlsStream::NativeTls(mut mid) => {
680 if !mid.get_mut().try_connect()? {
681 return Err(HandshakeError::WouldBlock(mid.into()));
682 }
683 mid.handshake()?.into()
684 }
685 #[cfg(feature = "openssl")]
686 MidHandshakeTlsStream::Openssl(mut mid) => {
687 if !mid.get_mut().try_connect()? {
688 return Err(HandshakeError::WouldBlock(mid.into()));
689 }
690 mid.handshake()?.into()
691 }
692 #[cfg(feature = "rustls-common")]
693 MidHandshakeTlsStream::Rustls(mut mid) => {
694 if !mid.get_mut().try_connect()? {
695 return Err(HandshakeError::WouldBlock(mid.into()));
696 }
697 mid.handshake()?.into()
698 }
699 })
700 }
701}
702
703impl From<TcpStream> for MidHandshakeTlsStream {
704 fn from(mid: TcpStream) -> Self {
705 MidHandshakeTlsStream::Plain(mid)
706 }
707}
708
#[cfg(feature = "native-tls")]
impl From<NativeTlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
    /// Wraps a pending native-tls handshake in the `NativeTls` variant.
    fn from(mid: NativeTlsMidHandshakeTlsStream) -> Self {
        Self::NativeTls(mid)
    }
}
715
#[cfg(feature = "openssl")]
impl From<OpenSslMidHandshakeTlsStream> for MidHandshakeTlsStream {
    /// Wraps a pending openssl handshake in the `Openssl` variant.
    fn from(mid: OpenSslMidHandshakeTlsStream) -> Self {
        Self::Openssl(mid)
    }
}
722
#[cfg(feature = "rustls-common")]
impl From<RustlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
    /// Wraps a pending rustls handshake in the `Rustls` variant.
    fn from(mid: RustlsMidHandshakeTlsStream) -> Self {
        Self::Rustls(mid)
    }
}
729
730impl fmt::Display for MidHandshakeTlsStream {
731 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
732 f.write_str("MidHandshakeTlsStream")
733 }
734}
735
/// Error returned when upgrading a stream to TLS does not complete.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum HandshakeError {
    /// The handshake could not finish yet; the pending stream is handed back
    /// so the caller can retry with `MidHandshakeTlsStream::handshake`.
    WouldBlock(MidHandshakeTlsStream),
    /// The handshake failed with the contained I/O error.
    Failure(io::Error),
}
745
746impl HandshakeError {
747 pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
749 match self {
750 Self::WouldBlock(mid) => Ok(mid),
751 Self::Failure(error) => Err(error),
752 }
753 }
754}
755
756impl fmt::Display for HandshakeError {
757 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
758 match self {
759 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
760 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {}", err)),
761 }
762 }
763}
764
765impl Error for HandshakeError {
766 fn source(&self) -> Option<&(dyn Error + 'static)> {
767 match self {
768 HandshakeError::Failure(err) => Some(err),
769 _ => None,
770 }
771 }
772}
773
#[cfg(feature = "native-tls")]
impl From<NativeTlsHandshakeError> for HandshakeError {
    /// Maps a native-tls handshake error onto this crate's error type.
    ///
    /// A would-block keeps its pending stream; a failure is wrapped in an
    /// `io::Error` of kind `Other`.
    fn from(error: NativeTlsHandshakeError) -> Self {
        use native_tls::HandshakeError as Native;

        match error {
            Native::Failure(cause) => Self::Failure(io::Error::new(io::ErrorKind::Other, cause)),
            Native::WouldBlock(pending) => Self::WouldBlock(MidHandshakeTlsStream::from(pending)),
        }
    }
}
785
#[cfg(feature = "openssl")]
impl From<OpenSslHandshakeError> for HandshakeError {
    /// Maps an openssl handshake error onto this crate's error type.
    ///
    /// A would-block keeps its pending stream; a handshake failure is wrapped
    /// in an `io::Error` of kind `Other`; a setup failure goes through the
    /// `OpenSslErrorStack` conversion.
    fn from(error: OpenSslHandshakeError) -> Self {
        use openssl::ssl::HandshakeError as Ssl;

        match error {
            Ssl::SetupFailure(stack) => Self::from(stack),
            Ssl::Failure(failed) => {
                Self::Failure(io::Error::new(io::ErrorKind::Other, failed.into_error()))
            }
            Ssl::WouldBlock(pending) => Self::WouldBlock(MidHandshakeTlsStream::from(pending)),
        }
    }
}
798
#[cfg(feature = "openssl")]
impl From<OpenSslErrorStack> for HandshakeError {
    /// Converts the openssl error stack to an `io::Error` and reports it as a failure.
    fn from(error: OpenSslErrorStack) -> Self {
        let cause = io::Error::from(error);
        Self::Failure(cause)
    }
}
805
#[cfg(feature = "rustls-common")]
impl From<RustlsHandshakeError> for HandshakeError {
    /// Maps a rustls-connector handshake error onto this crate's error type.
    ///
    /// A would-block has its pending stream moved out of the box; a failure
    /// is passed through as is.
    fn from(error: RustlsHandshakeError) -> Self {
        use rustls_connector::HandshakeError as Rustls;

        match error {
            Rustls::Failure(cause) => Self::Failure(cause),
            Rustls::WouldBlock(pending) => Self::WouldBlock(MidHandshakeTlsStream::from(*pending)),
        }
    }
}
817
818impl From<io::Error> for HandshakeError {
819 fn from(err: io::Error) -> Self {
820 HandshakeError::Failure(err)
821 }
822}
823
824#[cfg(unix)]
825mod sys {
826 use crate::TcpStream;
827 use std::{
828 net::TcpStream as StdTcpStream,
829 os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd},
830 };
831
832 impl AsFd for TcpStream {
833 fn as_fd(&self) -> BorrowedFd<'_> {
834 <StdTcpStream as AsFd>::as_fd(self)
835 }
836 }
837
838 impl AsRawFd for TcpStream {
839 fn as_raw_fd(&self) -> RawFd {
840 <StdTcpStream as AsRawFd>::as_raw_fd(self)
841 }
842 }
843
844 impl AsRawFd for &TcpStream {
845 fn as_raw_fd(&self) -> RawFd {
846 <StdTcpStream as AsRawFd>::as_raw_fd(self)
847 }
848 }
849
850 impl FromRawFd for TcpStream {
851 unsafe fn from_raw_fd(fd: RawFd) -> Self {
852 Self::Plain(unsafe { StdTcpStream::from_raw_fd(fd) }, false)
853 }
854 }
855}
856
/// Windows socket traits, all delegating to the underlying std socket.
#[cfg(windows)]
mod sys {
    use crate::TcpStream;
    use std::{
        net::TcpStream as StdTcpStream,
        os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, RawSocket},
    };

    impl AsSocket for TcpStream {
        fn as_socket(&self) -> BorrowedSocket<'_> {
            let inner: &StdTcpStream = self;
            inner.as_socket()
        }
    }

    impl AsRawSocket for TcpStream {
        fn as_raw_socket(&self) -> RawSocket {
            let inner: &StdTcpStream = self;
            inner.as_raw_socket()
        }
    }

    impl AsRawSocket for &TcpStream {
        fn as_raw_socket(&self) -> RawSocket {
            let inner: &StdTcpStream = self;
            inner.as_raw_socket()
        }
    }

    impl FromRawSocket for TcpStream {
        /// Builds a plain stream, flagged as not yet connected, from a raw socket.
        unsafe fn from_raw_socket(socket: RawSocket) -> Self {
            // SAFETY: the caller of this `unsafe fn` upholds the contract of
            // `FromRawSocket::from_raw_socket` for `socket`, which is passed on unchanged.
            let inner = unsafe { StdTcpStream::from_raw_socket(socket) };
            Self::Plain(inner, false)
        }
    }
}