1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4 clippy::must_use_candidate,
5 clippy::unwrap_in_result,
6 clippy::panic_in_result_fn
7)]
8#![allow(clippy::large_enum_variant, clippy::result_large_err)]
9
10use cfg_if::cfg_if;
106use std::{
107 convert::TryFrom,
108 error::Error,
109 fmt,
110 io::{self, IoSlice, IoSliceMut, Read, Write},
111 net::{TcpStream as StdTcpStream, ToSocketAddrs},
112 ops::{Deref, DerefMut},
113 time::Duration,
114};
115
116#[cfg(feature = "rustls")]
117mod rustls_impl;
118#[cfg(feature = "rustls")]
119pub use rustls_impl::*;
120
121#[cfg(feature = "native-tls")]
122mod native_tls_impl;
123#[cfg(feature = "native-tls")]
124pub use native_tls_impl::*;
125
126#[cfg(feature = "openssl")]
127mod openssl_impl;
128#[cfg(feature = "openssl")]
129pub use openssl_impl::*;
130
// Forwards a method call to whichever variant `self` currently holds (plain socket or one of the
// enabled TLS backends).
//
// Usage: `fwd_impl!(self, method, arg1, arg2)`; use `fwd_impl!(self, method,)` for a method that
// takes no arguments (the comma after the method name is required by the matcher).
macro_rules! fwd_impl {
    ($this:ident, $func:ident, $($arg:expr),*) => {
        match $this {
            Self::Plain(inner) => inner.$func($($arg),*),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(inner) => inner.$func($($arg),*),
            #[cfg(feature = "openssl")]
            Self::Openssl(inner) => inner.$func($($arg),*),
            #[cfg(feature = "rustls")]
            Self::Rustls(inner) => inner.$func($($arg),*),
        }
    };
}
144
// Pinned counterpart of `fwd_impl!`: unpins `self` (via `Pin::get_mut`, so the enum must be
// `Unpin`) and forwards a poll-style call to the active variant re-pinned with `Pin::new`.
// The TLS arms are gated on the `*-futures` features; presumably each of those implies its base
// TLS feature so the match stays exhaustive — TODO confirm against Cargo.toml.
#[cfg(feature = "futures")]
macro_rules! fwd_pin_impl {
    ($this:ident, $func:ident, $($arg:expr),*) => {
        match $this.get_mut() {
            Self::Plain(inner) => Pin::new(inner).$func($($arg),*),
            #[cfg(feature = "native-tls-futures")]
            Self::NativeTls(inner) => Pin::new(inner).$func($($arg),*),
            #[cfg(feature = "openssl-futures")]
            Self::Openssl(inner) => Pin::new(inner).$func($($arg),*),
            #[cfg(feature = "rustls-futures")]
            Self::Rustls(inner) => Pin::new(inner).$func($($arg),*),
        }
    };
}
159
160#[cfg(feature = "futures")]
161mod futures;
162#[cfg(feature = "futures")]
163pub use futures::*;
164
/// A TCP stream that is either plain or wrapped by one of the TLS backends enabled at compile
/// time.
///
/// Each variant gives access to the underlying [`std::net::TcpStream`], and reads/writes are
/// forwarded to whichever layer is active.
#[non_exhaustive]
pub enum TcpStream {
    /// A plain, unencrypted TCP stream.
    Plain(StdTcpStream),
    /// A TCP stream wrapped by the `native-tls` backend.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsStream),
    /// A TCP stream wrapped by the `openssl` backend.
    #[cfg(feature = "openssl")]
    Openssl(OpensslStream),
    /// A TCP stream wrapped by the `rustls` backend.
    #[cfg(feature = "rustls")]
    Rustls(RustlsStream),
}
180
181#[derive(Default, Debug, PartialEq)]
183pub struct TLSConfig<'data, 'key, 'chain> {
184 pub identity: Option<Identity<'data, 'key>>,
186 pub cert_chain: Option<&'chain str>,
188}
189
190#[derive(Clone, Default, Debug, PartialEq)]
192pub struct OwnedTLSConfig {
193 pub identity: Option<OwnedIdentity>,
195 pub cert_chain: Option<String>,
197}
198
199impl OwnedTLSConfig {
200 #[must_use]
202 pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
203 TLSConfig {
204 identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
205 cert_chain: self.cert_chain.as_deref(),
206 }
207 }
208}
209
/// Borrowed identity material handed to the TLS backend, in one of two encodings.
///
/// See [`OwnedIdentity`] for the owned counterpart.
#[derive(Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// A PKCS#12 archive together with its password.
    PKCS12 {
        /// The PKCS#12 archive bytes (DER, per the field name).
        der: &'data [u8],
        /// The password protecting the archive.
        password: &'key str,
    },
    /// A PEM certificate together with a PKCS#8 private key.
    PKCS8 {
        /// The certificate bytes (PEM, per the field name).
        pem: &'data [u8],
        /// The private key bytes; NOTE(review): whether PEM or DER is expected is decided by
        /// the backend modules — confirm there.
        key: &'key [u8],
    },
}
230
/// Owned counterpart of [`Identity`]; borrow it with [`OwnedIdentity::as_ref`].
#[derive(Clone, Debug, PartialEq)]
pub enum OwnedIdentity {
    /// A PKCS#12 archive together with its password.
    PKCS12 {
        /// The PKCS#12 archive bytes (DER, per the field name).
        der: Vec<u8>,
        /// The password protecting the archive.
        password: String,
    },
    /// A PEM certificate together with a PKCS#8 private key.
    PKCS8 {
        /// The certificate bytes (PEM, per the field name).
        pem: Vec<u8>,
        /// The private key bytes; NOTE(review): whether PEM or DER is expected is decided by
        /// the backend modules — confirm there.
        key: Vec<u8>,
    },
}
251
252impl OwnedIdentity {
253 #[must_use]
255 pub fn as_ref(&self) -> Identity<'_, '_> {
256 match self {
257 Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
258 Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
259 }
260 }
261}
262
263pub type HandshakeResult = Result<TcpStream, HandshakeError>;
265
266impl TcpStream {
267 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
269 connect_std(addr, None).and_then(Self::try_from)
270 }
271
272 pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
274 connect_std(addr, Some(timeout)).and_then(Self::try_from)
275 }
276
277 pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
279 Self::try_from(stream)
280 }
281
282 pub fn is_readable(&self) -> io::Result<()> {
284 self.deref().read(&mut []).map(|_| ())
285 }
286
287 pub fn is_writable(&self) -> io::Result<()> {
289 is_writable(self.deref())
290 }
291
292 pub fn try_connect(&mut self) -> io::Result<bool> {
297 try_connect(self)
298 }
299
300 pub fn into_tls(
308 self,
309 domain: &str,
310 config: TLSConfig<'_, '_, '_>,
311 ) -> Result<Self, HandshakeError> {
312 into_tls_impl(self, domain, config)
313 }
314
315 #[cfg(feature = "native-tls")]
316 pub fn into_native_tls(
318 self,
319 connector: &NativeTlsConnector,
320 domain: &str,
321 ) -> Result<Self, HandshakeError> {
322 Ok(connector.connect(domain, self.into_plain()?)?.into())
323 }
324
325 #[cfg(feature = "openssl")]
326 pub fn into_openssl(
328 self,
329 connector: &OpensslConnector,
330 domain: &str,
331 ) -> Result<Self, HandshakeError> {
332 Ok(connector.connect(domain, self.into_plain()?)?.into())
333 }
334
335 #[cfg(feature = "rustls")]
336 pub fn into_rustls(
338 self,
339 connector: &RustlsConnector,
340 domain: &str,
341 ) -> Result<Self, HandshakeError> {
342 Ok(connector.connect(domain, self.into_plain()?)?.into())
343 }
344
345 #[allow(irrefutable_let_patterns)]
346 fn into_plain(self) -> Result<StdTcpStream, io::Error> {
347 if let Self::Plain(plain) = self {
348 Ok(plain)
349 } else {
350 Err(io::Error::new(
351 io::ErrorKind::AlreadyExists,
352 "already a TLS stream",
353 ))
354 }
355 }
356}
357
/// Resolves `addr` and opens a TCP connection.
///
/// Without a timeout this defers to `StdTcpStream::connect`. With a timeout, each resolved
/// address is tried in order via `StdTcpStream::connect_timeout`; the first success is returned,
/// otherwise the last attempt's error, or an `AddrNotAvailable` error when resolution produced
/// no addresses at all.
fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
    let limit = match timeout {
        Some(limit) => limit,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_error = None;
    for candidate in addr.to_socket_addrs()? {
        match StdTcpStream::connect_timeout(&candidate, limit) {
            Ok(stream) => return Ok(stream),
            Err(error) => last_error = Some(error),
        }
    }

    Err(last_error.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
    }))
}
375
376fn try_connect(stream: &mut StdTcpStream) -> io::Result<bool> {
377 match is_writable(stream) {
378 Ok(()) => Ok(true),
379 Err(err)
380 if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected].contains(&err.kind()) =>
381 {
382 Ok(false)
383 }
384 Err(err) => Err(err),
385 }
386}
387
/// Issues a zero-length write on `stream`, reporting only success or the I/O error.
fn is_writable(stream: &StdTcpStream) -> io::Result<()> {
    let mut socket: &StdTcpStream = stream;
    // The byte count of an empty write carries no information; only success matters.
    socket.write(&[]).map(drop)
}
391
392fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
393 cfg_if! {
394 if #[cfg(feature = "rustls")] {
395 into_rustls_impl(s, domain, config)
396 } else if #[cfg(feature = "openssl")] {
397 into_openssl_impl(s, domain, config)
398 } else if #[cfg(feature = "native-tls")] {
399 into_native_tls_impl(s, domain, config)
400 } else {
401 let _ = (domain, config);
402 Ok(TcpStream::Plain(s.into_plain()?))
403 }
404 }
405}
406
407impl TryFrom<StdTcpStream> for TcpStream {
408 type Error = io::Error;
409
410 fn try_from(s: StdTcpStream) -> io::Result<Self> {
411 s.set_nodelay(true)?;
412 let mut this = Self::Plain(s);
413 this.try_connect()?;
414 Ok(this)
415 }
416}
417
418impl Deref for TcpStream {
419 type Target = StdTcpStream;
420
421 fn deref(&self) -> &Self::Target {
422 match self {
423 Self::Plain(plain) => plain,
424 #[cfg(feature = "native-tls")]
425 Self::NativeTls(tls) => tls.get_ref(),
426 #[cfg(feature = "openssl")]
427 Self::Openssl(tls) => tls.get_ref(),
428 #[cfg(feature = "rustls")]
429 Self::Rustls(tls) => tls.get_ref(),
430 }
431 }
432}
433
434impl DerefMut for TcpStream {
435 fn deref_mut(&mut self) -> &mut Self::Target {
436 match self {
437 Self::Plain(plain) => plain,
438 #[cfg(feature = "native-tls")]
439 Self::NativeTls(tls) => tls.get_mut(),
440 #[cfg(feature = "openssl")]
441 Self::Openssl(tls) => tls.get_mut(),
442 #[cfg(feature = "rustls")]
443 Self::Rustls(tls) => tls.get_mut(),
444 }
445 }
446}
447
448impl Read for TcpStream {
449 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
450 fwd_impl!(self, read, buf)
451 }
452
453 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
454 fwd_impl!(self, read_vectored, bufs)
455 }
456
457 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
458 fwd_impl!(self, read_to_end, buf)
459 }
460
461 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
462 fwd_impl!(self, read_to_string, buf)
463 }
464
465 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
466 fwd_impl!(self, read_exact, buf)
467 }
468}
469
470impl Write for TcpStream {
471 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
472 fwd_impl!(self, write, buf)
473 }
474
475 fn flush(&mut self) -> io::Result<()> {
476 fwd_impl!(self, flush,)
477 }
478
479 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
480 fwd_impl!(self, write_vectored, bufs)
481 }
482
483 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
484 fwd_impl!(self, write_all, buf)
485 }
486
487 fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
488 fwd_impl!(self, write_fmt, fmt)
489 }
490}
491
492impl fmt::Debug for TcpStream {
493 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
494 f.debug_struct("TcpStream")
495 .field("inner", self.deref())
496 .finish()
497 }
498}
499
500#[derive(Debug)]
502pub enum MidHandshakeTlsStream {
503 Plain(TcpStream),
505 #[cfg(feature = "native-tls")]
506 NativeTls(NativeTlsMidHandshakeTlsStream),
508 #[cfg(feature = "openssl")]
509 Openssl(OpensslMidHandshakeTlsStream),
511 #[cfg(feature = "rustls")]
512 Rustls(RustlsMidHandshakeTlsStream),
514}
515
516impl MidHandshakeTlsStream {
517 #[must_use]
519 pub fn get_ref(&self) -> &StdTcpStream {
520 match self {
521 Self::Plain(mid) => mid,
522 #[cfg(feature = "native-tls")]
523 Self::NativeTls(mid) => mid.get_ref(),
524 #[cfg(feature = "openssl")]
525 Self::Openssl(mid) => mid.get_ref(),
526 #[cfg(feature = "rustls")]
527 Self::Rustls(mid) => mid.get_ref(),
528 }
529 }
530
531 #[must_use]
533 pub fn get_mut(&mut self) -> &mut StdTcpStream {
534 match self {
535 Self::Plain(mid) => mid,
536 #[cfg(feature = "native-tls")]
537 Self::NativeTls(mid) => mid.get_mut(),
538 #[cfg(feature = "openssl")]
539 Self::Openssl(mid) => mid.get_mut(),
540 #[cfg(feature = "rustls")]
541 Self::Rustls(mid) => mid.get_mut(),
542 }
543 }
544
545 pub fn handshake(mut self) -> HandshakeResult {
550 if !try_connect(self.get_mut())? {
551 return Err(HandshakeError::WouldBlock(self));
552 }
553
554 Ok(match self {
555 Self::Plain(mid) => mid,
556 #[cfg(feature = "native-tls")]
557 Self::NativeTls(mid) => mid.handshake()?.into(),
558 #[cfg(feature = "openssl")]
559 Self::Openssl(mid) => mid.handshake()?.into(),
560 #[cfg(feature = "rustls")]
561 Self::Rustls(mid) => mid.handshake()?.into(),
562 })
563 }
564}
565
566impl From<TcpStream> for MidHandshakeTlsStream {
567 fn from(mid: TcpStream) -> Self {
568 Self::Plain(mid)
569 }
570}
571
572impl fmt::Display for MidHandshakeTlsStream {
573 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
574 f.write_str("MidHandshakeTlsStream")
575 }
576}
577
578#[derive(Debug)]
580pub enum HandshakeError {
581 WouldBlock(MidHandshakeTlsStream),
583 Failure(io::Error),
585}
586
587impl HandshakeError {
588 pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
590 match self {
591 Self::WouldBlock(mid) => Ok(mid),
592 Self::Failure(error) => Err(error),
593 }
594 }
595}
596
597impl fmt::Display for HandshakeError {
598 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
599 match self {
600 Self::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
601 Self::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
602 }
603 }
604}
605
606impl Error for HandshakeError {
607 fn source(&self) -> Option<&(dyn Error + 'static)> {
608 match self {
609 Self::Failure(err) => Some(err),
610 _ => None,
611 }
612 }
613}
614
615impl From<io::Error> for HandshakeError {
616 fn from(err: io::Error) -> Self {
617 Self::Failure(err)
618 }
619}
620
621mod sys;