1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4 clippy::must_use_candidate,
5 clippy::unwrap_in_result,
6 clippy::panic_in_result_fn
7)]
8#![allow(clippy::large_enum_variant, clippy::result_large_err)]
9
10use cfg_if::cfg_if;
106use std::{
107 convert::TryFrom,
108 error::Error,
109 fmt,
110 io::{self, IoSlice, IoSliceMut, Read, Write},
111 net::{TcpStream as StdTcpStream, ToSocketAddrs},
112 ops::{Deref, DerefMut},
113 time::Duration,
114};
115
116#[cfg(feature = "rustls")]
117mod rustls_impl;
118#[cfg(feature = "rustls")]
119pub use rustls_impl::*;
120
121#[cfg(feature = "native-tls")]
122mod native_tls_impl;
123#[cfg(feature = "native-tls")]
124pub use native_tls_impl::*;
125
126#[cfg(feature = "openssl")]
127mod openssl_impl;
128#[cfg(feature = "openssl")]
129pub use openssl_impl::*;
130
// Forwards a method call to whichever variant of `Self` is active: the plain
// socket or one of the TLS backends compiled in via cargo features.
//
// Usage: `fwd_impl!(self, method, arg1, arg2)`. The comma after the method
// name is part of the pattern, so a call with no arguments is written with a
// trailing comma, e.g. `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($this:ident, $m:ident, $($arg:expr),*) => {
        match $this {
            Self::Plain(inner) => inner.$m($($arg),*),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(inner) => inner.$m($($arg),*),
            #[cfg(feature = "openssl")]
            Self::Openssl(inner) => inner.$m($($arg),*),
            #[cfg(feature = "rustls")]
            Self::Rustls(inner) => inner.$m($($arg),*),
        }
    };
}
144
// Pinned counterpart of `fwd_impl!`, for poll-style methods taking
// `self: Pin<&mut Self>`. The pin is unwrapped with `Pin::get_mut`, which needs
// `Self: Unpin`, and each inner stream is re-pinned before the call. Only the
// TLS backends whose `*-futures` feature is enabled are matched. `Pin` must be
// in scope wherever this macro is invoked.
macro_rules! fwd_pin_impl {
    ($this:ident, $m:ident, $($arg:expr),*) => {
        match $this.get_mut() {
            Self::Plain(inner) => Pin::new(inner).$m($($arg),*),
            #[cfg(feature = "native-tls-futures")]
            Self::NativeTls(inner) => Pin::new(inner).$m($($arg),*),
            #[cfg(feature = "openssl-futures")]
            Self::Openssl(inner) => Pin::new(inner).$m($($arg),*),
            #[cfg(feature = "rustls-futures")]
            Self::Rustls(inner) => Pin::new(inner).$m($($arg),*),
        }
    };
}
158
159#[cfg(feature = "futures")]
160mod futures;
161#[cfg(feature = "futures")]
162pub use futures::*;
163
/// A TCP stream that is either plain or wrapped in one of the TLS backends
/// enabled at compile time.
///
/// Every variant dereferences to the underlying [`std::net::TcpStream`], and
/// `Read`/`Write` calls are forwarded to the active layer.
#[non_exhaustive]
pub enum TcpStream {
    /// An unencrypted TCP stream.
    Plain(StdTcpStream),
    /// A stream secured by the `native-tls` backend.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsStream),
    /// A stream secured by the `openssl` backend.
    #[cfg(feature = "openssl")]
    Openssl(OpensslStream),
    /// A stream secured by the `rustls` backend.
    #[cfg(feature = "rustls")]
    Rustls(RustlsStream),
}
179
/// Borrowed TLS settings passed to [`TcpStream::into_tls`].
///
/// Both fields are optional; how `None` is interpreted is up to the TLS
/// backend selected at compile time. The struct only holds borrows, so it is
/// `Copy` and the same value can be reused for several connections.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Identity (certificate material plus private key) to present, if any.
    pub identity: Option<Identity<'data, 'key>>,
    /// Certificate chain as text, if any (its format is interpreted by the backend).
    pub cert_chain: Option<&'chain str>,
}

/// Owned counterpart of [`TLSConfig`], convenient for storing configuration.
///
/// Use [`OwnedTLSConfig::as_ref`] to obtain a borrowed [`TLSConfig`].
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub struct OwnedTLSConfig {
    /// Identity (certificate material plus private key) to present, if any.
    pub identity: Option<OwnedIdentity>,
    /// Certificate chain as text, if any.
    pub cert_chain: Option<String>,
}

impl OwnedTLSConfig {
    /// Borrows this configuration as a [`TLSConfig`] without copying any buffers.
    #[must_use]
    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
        TLSConfig {
            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
            cert_chain: self.cert_chain.as_deref(),
        }
    }
}

/// A TLS identity borrowed from the caller.
// NOTE(review): the derived `Debug` prints `password` and `key` verbatim;
// avoid logging values of this type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Identity<'data, 'key> {
    /// A PKCS#12 archive.
    PKCS12 {
        /// The archive bytes (DER-encoded, per the field name).
        der: &'data [u8],
        /// Password protecting the archive.
        password: &'key str,
    },
    /// A certificate plus a separate PKCS#8 private key.
    PKCS8 {
        /// Certificate bytes (PEM-encoded, per the field name).
        pem: &'data [u8],
        /// PKCS#8 private key bytes; the encoding is interpreted by the backend.
        key: &'key [u8],
    },
}

/// Owned counterpart of [`Identity`].
// NOTE(review): the derived `Debug` prints `password` and `key` verbatim;
// avoid logging values of this type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OwnedIdentity {
    /// A PKCS#12 archive.
    PKCS12 {
        /// The archive bytes (DER-encoded, per the field name).
        der: Vec<u8>,
        /// Password protecting the archive.
        password: String,
    },
    /// A certificate plus a separate PKCS#8 private key.
    PKCS8 {
        /// Certificate bytes (PEM-encoded, per the field name).
        pem: Vec<u8>,
        /// PKCS#8 private key bytes; the encoding is interpreted by the backend.
        key: Vec<u8>,
    },
}

impl OwnedIdentity {
    /// Borrows this identity as an [`Identity`] without copying any buffers.
    #[must_use]
    pub fn as_ref(&self) -> Identity<'_, '_> {
        match self {
            Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
            Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
        }
    }
}
261
/// Outcome of a TLS handshake attempt: the ready [`TcpStream`], or a
/// [`HandshakeError`] (whose `WouldBlock` variant carries a resumable
/// [`MidHandshakeTlsStream`]).
pub type HandshakeResult = Result<TcpStream, HandshakeError>;
264
265impl TcpStream {
266 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
268 connect_std(addr, None).and_then(Self::try_from)
269 }
270
271 pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
273 connect_std(addr, Some(timeout)).and_then(Self::try_from)
274 }
275
276 pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
278 Self::try_from(stream)
279 }
280
281 pub fn is_readable(&self) -> io::Result<()> {
283 self.deref().read(&mut []).map(|_| ())
284 }
285
286 pub fn is_writable(&self) -> io::Result<()> {
288 is_writable(self.deref())
289 }
290
291 pub fn try_connect(&mut self) -> io::Result<bool> {
296 try_connect(self)
297 }
298
299 pub fn into_tls(
307 self,
308 domain: &str,
309 config: TLSConfig<'_, '_, '_>,
310 ) -> Result<Self, HandshakeError> {
311 into_tls_impl(self, domain, config)
312 }
313
314 #[cfg(feature = "native-tls")]
315 pub fn into_native_tls(
317 self,
318 connector: &NativeTlsConnector,
319 domain: &str,
320 ) -> Result<Self, HandshakeError> {
321 Ok(connector.connect(domain, self.into_plain()?)?.into())
322 }
323
324 #[cfg(feature = "openssl")]
325 pub fn into_openssl(
327 self,
328 connector: &OpensslConnector,
329 domain: &str,
330 ) -> Result<Self, HandshakeError> {
331 Ok(connector.connect(domain, self.into_plain()?)?.into())
332 }
333
334 #[cfg(feature = "rustls")]
335 pub fn into_rustls(
337 self,
338 connector: &RustlsConnector,
339 domain: &str,
340 ) -> Result<Self, HandshakeError> {
341 Ok(connector.connect(domain, self.into_plain()?)?.into())
342 }
343
344 #[allow(irrefutable_let_patterns)]
345 fn into_plain(self) -> Result<StdTcpStream, io::Error> {
346 if let Self::Plain(plain) = self {
347 Ok(plain)
348 } else {
349 Err(io::Error::new(
350 io::ErrorKind::AlreadyExists,
351 "already a TLS stream",
352 ))
353 }
354 }
355}
356
/// Connects a standard-library TCP stream to `addr`.
///
/// Without a timeout this defers entirely to `StdTcpStream::connect`. With a
/// timeout, each resolved address is tried in order with `connect_timeout`.
/// The first success wins; otherwise the last error is returned. If no
/// address was resolved, an `AddrNotAvailable` error is returned.
fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
    let limit = match timeout {
        Some(limit) => limit,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_failure: Option<io::Error> = None;
    for candidate in addr.to_socket_addrs()? {
        match StdTcpStream::connect_timeout(&candidate, limit) {
            Ok(connected) => return Ok(connected),
            Err(failure) => last_failure = Some(failure),
        }
    }

    let unresolved =
        || io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host");
    Err(last_failure.unwrap_or_else(unresolved))
}
374
375fn try_connect(stream: &mut StdTcpStream) -> io::Result<bool> {
376 match is_writable(stream) {
377 Ok(()) => Ok(true),
378 Err(err)
379 if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected].contains(&err.kind()) =>
380 {
381 Ok(false)
382 }
383 Err(err) => Err(err),
384 }
385}
386
/// Issues a zero-length write on `stream` and keeps only its success or error.
///
/// `try_connect` treats `WouldBlock`/`NotConnected` from this probe as
/// "connection still in progress".
fn is_writable(stream: &StdTcpStream) -> io::Result<()> {
    // `Write` is implemented for `&TcpStream`, so a mutable binding of the
    // shared reference is enough to write.
    let mut writer = stream;
    writer.write(&[]).map(|_written| ())
}
390
/// Upgrades `s` to TLS with the single backend chosen at compile time.
///
/// Backend precedence when several TLS features are enabled: `rustls`, then
/// `openssl`, then `native-tls`. With no TLS feature enabled, `domain` and
/// `config` are ignored and the stream is returned as plain TCP. In that
/// build `Plain` is the only `TcpStream` variant, so `into_plain` cannot fail.
fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
    cfg_if! {
        if #[cfg(feature = "rustls")] {
            into_rustls_impl(s, domain, config)
        } else if #[cfg(feature = "openssl")] {
            into_openssl_impl(s, domain, config)
        } else if #[cfg(feature = "native-tls")] {
            into_native_tls_impl(s, domain, config)
        } else {
            // Consume the otherwise-unused parameters to silence warnings in
            // the no-TLS build.
            let _ = (domain, config);
            Ok(TcpStream::Plain(s.into_plain()?))
        }
    }
}
405
406impl TryFrom<StdTcpStream> for TcpStream {
407 type Error = io::Error;
408
409 fn try_from(s: StdTcpStream) -> io::Result<Self> {
410 s.set_nodelay(true)?;
411 let mut this = Self::Plain(s);
412 this.try_connect()?;
413 Ok(this)
414 }
415}
416
417impl Deref for TcpStream {
418 type Target = StdTcpStream;
419
420 fn deref(&self) -> &Self::Target {
421 match self {
422 Self::Plain(plain) => plain,
423 #[cfg(feature = "native-tls")]
424 Self::NativeTls(tls) => tls.get_ref(),
425 #[cfg(feature = "openssl")]
426 Self::Openssl(tls) => tls.get_ref(),
427 #[cfg(feature = "rustls")]
428 Self::Rustls(tls) => tls.get_ref(),
429 }
430 }
431}
432
433impl DerefMut for TcpStream {
434 fn deref_mut(&mut self) -> &mut Self::Target {
435 match self {
436 Self::Plain(plain) => plain,
437 #[cfg(feature = "native-tls")]
438 Self::NativeTls(tls) => tls.get_mut(),
439 #[cfg(feature = "openssl")]
440 Self::Openssl(tls) => tls.get_mut(),
441 #[cfg(feature = "rustls")]
442 Self::Rustls(tls) => tls.get_mut(),
443 }
444 }
445}
446
447impl Read for TcpStream {
448 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
449 fwd_impl!(self, read, buf)
450 }
451
452 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
453 fwd_impl!(self, read_vectored, bufs)
454 }
455
456 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
457 fwd_impl!(self, read_to_end, buf)
458 }
459
460 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
461 fwd_impl!(self, read_to_string, buf)
462 }
463
464 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
465 fwd_impl!(self, read_exact, buf)
466 }
467}
468
469impl Write for TcpStream {
470 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
471 fwd_impl!(self, write, buf)
472 }
473
474 fn flush(&mut self) -> io::Result<()> {
475 fwd_impl!(self, flush,)
476 }
477
478 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
479 fwd_impl!(self, write_vectored, bufs)
480 }
481
482 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
483 fwd_impl!(self, write_all, buf)
484 }
485
486 fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
487 fwd_impl!(self, write_fmt, fmt)
488 }
489}
490
491impl fmt::Debug for TcpStream {
492 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
493 f.debug_struct("TcpStream")
494 .field("inner", self.deref())
495 .finish()
496 }
497}
498
/// A stream whose TCP connection or TLS handshake has not completed yet.
///
/// Resume it with `MidHandshakeTlsStream::handshake`.
#[derive(Debug)]
pub enum MidHandshakeTlsStream {
    /// A plain stream whose TCP connection may still be in progress;
    /// `handshake` returns it unchanged once the connection probe succeeds.
    Plain(TcpStream),
    /// A `native-tls` handshake in progress.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsMidHandshakeTlsStream),
    /// An OpenSSL handshake in progress.
    #[cfg(feature = "openssl")]
    Openssl(OpensslMidHandshakeTlsStream),
    /// A rustls handshake in progress.
    #[cfg(feature = "rustls")]
    Rustls(RustlsMidHandshakeTlsStream),
}
514
515impl MidHandshakeTlsStream {
516 #[must_use]
518 pub fn get_ref(&self) -> &StdTcpStream {
519 match self {
520 Self::Plain(mid) => mid,
521 #[cfg(feature = "native-tls")]
522 Self::NativeTls(mid) => mid.get_ref(),
523 #[cfg(feature = "openssl")]
524 Self::Openssl(mid) => mid.get_ref(),
525 #[cfg(feature = "rustls")]
526 Self::Rustls(mid) => mid.get_ref(),
527 }
528 }
529
530 #[must_use]
532 pub fn get_mut(&mut self) -> &mut StdTcpStream {
533 match self {
534 Self::Plain(mid) => mid,
535 #[cfg(feature = "native-tls")]
536 Self::NativeTls(mid) => mid.get_mut(),
537 #[cfg(feature = "openssl")]
538 Self::Openssl(mid) => mid.get_mut(),
539 #[cfg(feature = "rustls")]
540 Self::Rustls(mid) => mid.get_mut(),
541 }
542 }
543
544 pub fn handshake(mut self) -> HandshakeResult {
549 if !try_connect(self.get_mut())? {
550 return Err(HandshakeError::WouldBlock(self));
551 }
552
553 Ok(match self {
554 Self::Plain(mid) => mid,
555 #[cfg(feature = "native-tls")]
556 Self::NativeTls(mid) => mid.handshake()?.into(),
557 #[cfg(feature = "openssl")]
558 Self::Openssl(mid) => mid.handshake()?.into(),
559 #[cfg(feature = "rustls")]
560 Self::Rustls(mid) => mid.handshake()?.into(),
561 })
562 }
563}
564
565impl From<TcpStream> for MidHandshakeTlsStream {
566 fn from(mid: TcpStream) -> Self {
567 Self::Plain(mid)
568 }
569}
570
571impl fmt::Display for MidHandshakeTlsStream {
572 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
573 f.write_str("MidHandshakeTlsStream")
574 }
575}
576
/// Error returned when a TLS handshake (or the TCP connect before it) does
/// not complete.
#[derive(Debug)]
pub enum HandshakeError {
    /// Progress would block. The paused stream is returned so the handshake
    /// can be resumed with `MidHandshakeTlsStream::handshake`.
    WouldBlock(MidHandshakeTlsStream),
    /// The handshake failed with an I/O error.
    Failure(io::Error),
}
585
586impl HandshakeError {
587 pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
589 match self {
590 Self::WouldBlock(mid) => Ok(mid),
591 Self::Failure(error) => Err(error),
592 }
593 }
594}
595
596impl fmt::Display for HandshakeError {
597 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
598 match self {
599 Self::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
600 Self::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
601 }
602 }
603}
604
605impl Error for HandshakeError {
606 fn source(&self) -> Option<&(dyn Error + 'static)> {
607 match self {
608 Self::Failure(err) => Some(err),
609 _ => None,
610 }
611 }
612}
613
614impl From<io::Error> for HandshakeError {
615 fn from(err: io::Error) -> Self {
616 Self::Failure(err)
617 }
618}
619
620mod sys;