1#![deny(missing_docs)]
2#![allow(clippy::result_large_err)]
3
4use cfg_if::cfg_if;
50use std::{
51 convert::TryFrom,
52 error::Error,
53 fmt,
54 io::{self, IoSlice, IoSliceMut, Read, Write},
55 net::{TcpStream as StdTcpStream, ToSocketAddrs},
56 ops::{Deref, DerefMut},
57 time::Duration,
58};
59
60#[cfg(feature = "rustls-common")]
61mod rustls_impl;
62#[cfg(feature = "rustls-common")]
63pub use rustls_impl::*;
64
65#[cfg(feature = "native-tls")]
66mod native_tls_impl;
67#[cfg(feature = "native-tls")]
68pub use native_tls_impl::*;
69
70#[cfg(feature = "openssl")]
71mod openssl_impl;
72#[cfg(feature = "openssl")]
73pub use openssl_impl::*;
74
/// A TCP stream that is either plain or wrapped by one of the enabled TLS backends.
///
/// Every variant dereferences to the underlying [`std::net::TcpStream`] (see the
/// `Deref` impl), and `Read`/`Write` calls are forwarded to the active variant.
pub enum TcpStream {
    /// A plain TCP stream. The `bool` is the cached "connection established" flag
    /// read by [`TcpStream::is_connected`] and set by [`TcpStream::try_connect`].
    Plain(StdTcpStream, bool),
    /// A stream wrapped by the `native-tls` backend.
    #[cfg(feature = "native-tls")]
    NativeTls(Box<NativeTlsStream>),
    /// A stream wrapped by the `openssl` backend.
    #[cfg(feature = "openssl")]
    OpenSsl(Box<OpenSslStream>),
    /// A stream wrapped by the `rustls` backend.
    #[cfg(feature = "rustls-common")]
    Rustls(Box<RustlsStream>),
}
89
/// Borrowed TLS settings accepted by [`TcpStream::into_tls`].
///
/// Both fields are optional; `TLSConfig::default()` leaves them unset. An owned
/// counterpart is [`OwnedTLSConfig`].
#[derive(Default, Debug, PartialEq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Optional identity (see [`Identity`]) handed to the selected TLS backend.
    pub identity: Option<Identity<'data, 'key>>,
    /// Optional certificate chain, given as text.
    /// NOTE(review): presumably PEM — confirm against the backend modules.
    pub cert_chain: Option<&'chain str>,
}
98
/// Owned version of [`TLSConfig`], convenient to store; borrow it with
/// [`OwnedTLSConfig::as_ref`] when calling [`TcpStream::into_tls`].
#[derive(Default, Debug, PartialEq)]
pub struct OwnedTLSConfig {
    /// Optional owned identity (see [`OwnedIdentity`]).
    pub identity: Option<OwnedIdentity>,
    /// Optional certificate chain, given as text.
    pub cert_chain: Option<String>,
}
107
108impl OwnedTLSConfig {
109 #[must_use]
111 pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
112 TLSConfig {
113 identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
114 cert_chain: self.cert_chain.as_deref(),
115 }
116 }
117}
118
/// Borrowed identity material to hand to a TLS backend, in one of two encodings.
///
/// See [`OwnedIdentity`] for the owned counterpart.
#[derive(Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// A PKCS#12 bundle.
    PKCS12 {
        /// The bundle's DER-encoded bytes.
        der: &'data [u8],
        /// The password for the bundle.
        password: &'key str,
    },
    /// PEM data plus a PKCS#8 private key.
    PKCS8 {
        /// PEM-encoded bytes (presumably the certificate / chain — TODO confirm).
        pem: &'data [u8],
        /// The private key bytes.
        key: &'key [u8],
    },
}
139
/// Owned identity material; borrow it as an [`Identity`] with
/// [`OwnedIdentity::as_ref`].
#[derive(Debug, PartialEq)]
pub enum OwnedIdentity {
    /// A PKCS#12 bundle.
    PKCS12 {
        /// The bundle's DER-encoded bytes.
        der: Vec<u8>,
        /// The password for the bundle.
        password: String,
    },
    /// PEM data plus a PKCS#8 private key.
    PKCS8 {
        /// PEM-encoded bytes (presumably the certificate / chain — TODO confirm).
        pem: Vec<u8>,
        /// The private key bytes.
        key: Vec<u8>,
    },
}
160
161impl OwnedIdentity {
162 #[must_use]
164 pub fn as_ref(&self) -> Identity<'_, '_> {
165 match self {
166 Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
167 Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
168 }
169 }
170}
171
/// Outcome of a TLS handshake step: the resulting stream, or a [`HandshakeError`].
pub type HandshakeResult = Result<TcpStream, HandshakeError>;
174
175impl TcpStream {
176 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
178 connect_std(addr, None).and_then(Self::try_from)
179 }
180
181 pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
183 connect_std(addr, Some(timeout)).and_then(Self::try_from)
184 }
185
186 pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
188 Self::try_from(stream)
189 }
190
191 #[must_use]
193 #[allow(irrefutable_let_patterns)]
194 pub fn is_connected(&self) -> bool {
195 if let Self::Plain(_, connected) = self {
196 *connected
197 } else {
198 true
199 }
200 }
201
202 pub fn is_readable(&self) -> io::Result<()> {
204 self.deref().read(&mut []).map(|_| ())
205 }
206
207 pub fn is_writable(&self) -> io::Result<()> {
209 self.deref().write(&[]).map(|_| ())
210 }
211
212 #[allow(irrefutable_let_patterns)]
217 pub fn try_connect(&mut self) -> io::Result<bool> {
218 if self.is_connected() {
219 return Ok(true);
220 }
221 match self.is_writable() {
222 Ok(()) => {
223 if let Self::Plain(_, connected) = self {
224 *connected = true;
225 }
226 Ok(true)
227 }
228 Err(err)
229 if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected]
230 .contains(&err.kind()) =>
231 {
232 Ok(false)
233 }
234 Err(err) => Err(err),
235 }
236 }
237
238 pub fn into_tls(
240 self,
241 domain: &str,
242 config: TLSConfig<'_, '_, '_>,
243 ) -> Result<Self, HandshakeError> {
244 into_tls_impl(self, domain, config)
245 }
246
247 #[cfg(feature = "native-tls")]
248 pub fn into_native_tls(
250 self,
251 connector: &NativeTlsConnector,
252 domain: &str,
253 ) -> Result<Self, HandshakeError> {
254 Ok(connector.connect(domain, self.into_plain()?)?.into())
255 }
256
257 #[cfg(feature = "openssl")]
258 pub fn into_openssl(
260 self,
261 connector: &OpenSslConnector,
262 domain: &str,
263 ) -> Result<Self, HandshakeError> {
264 Ok(connector.connect(domain, self.into_plain()?)?.into())
265 }
266
267 #[cfg(feature = "rustls-common")]
268 pub fn into_rustls(
270 self,
271 connector: &RustlsConnector,
272 domain: &str,
273 ) -> Result<Self, HandshakeError> {
274 Ok(connector.connect(domain, self.into_plain()?)?.into())
275 }
276
277 #[allow(irrefutable_let_patterns)]
278 fn into_plain(self) -> Result<TcpStream, io::Error> {
279 if let TcpStream::Plain(plain, connected) = self {
280 Ok(TcpStream::Plain(plain, connected))
281 } else {
282 Err(io::Error::new(
283 io::ErrorKind::AlreadyExists,
284 "already a TLS stream",
285 ))
286 }
287 }
288}
289
290fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
291 let stream = connect_std_raw(addr, timeout)?;
292 stream.set_nodelay(true)?;
293 Ok(stream)
294}
295
/// Opens a std TCP connection to `addr`.
///
/// Without a timeout, this defers to `StdTcpStream::connect`. With a timeout, each
/// resolved address is tried in turn with `connect_timeout`. The first success wins;
/// otherwise the last attempt's error is returned. If `addr` resolved to nothing, an
/// `AddrNotAvailable` error is returned.
fn connect_std_raw<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> io::Result<StdTcpStream> {
    let per_addr_timeout = match timeout {
        Some(limit) => limit,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_failure = None;
    for candidate in addr.to_socket_addrs()? {
        match StdTcpStream::connect_timeout(&candidate, per_addr_timeout) {
            Ok(socket) => return Ok(socket),
            Err(failure) => last_failure = Some(failure),
        }
    }

    Err(last_failure.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
    }))
}
316
// Compile-time selection of the backend behind `TcpStream::into_tls`. Exactly one
// `into_tls_impl` is compiled, taken from the first branch whose feature is enabled.
// Precedence is:
//   1. rustls (native-certs config, then webpki-roots config, then the default config);
//   2. openssl;
//   3. native-tls.
// With no TLS feature, the stream is handed back through `into_plain`.
cfg_if! {
    if #[cfg(feature = "rustls-native-certs")] {
        // rustls using `RustlsConnectorConfig::new_with_native_certs()`, which is
        // fallible (note the `?`).
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_rustls_impl(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
        }
    } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
        // rustls using `RustlsConnectorConfig::new_with_webpki_roots_certs()`.
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_rustls_impl(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
        }
    } else if #[cfg(feature = "rustls-common")] {
        // rustls using `RustlsConnectorConfig::default()`.
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_rustls_impl(s, RustlsConnectorConfig::default(), domain, config)
        }
    } else if #[cfg(feature = "openssl")] {
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_openssl_impl(s, domain, config)
        }
    } else if #[cfg(feature = "native-tls")] {
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_native_tls_impl(s, domain, config)
        }
    } else {
        // No TLS backend compiled in: domain and config are ignored.
        fn into_tls_impl(s: TcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            Ok(s.into_plain()?)
        }
    }
}
344
345impl TryFrom<StdTcpStream> for TcpStream {
346 type Error = io::Error;
347
348 fn try_from(s: StdTcpStream) -> io::Result<Self> {
349 let mut this = TcpStream::Plain(s, false);
350 this.try_connect()?;
351 Ok(this)
352 }
353}
354
355impl Deref for TcpStream {
356 type Target = StdTcpStream;
357
358 fn deref(&self) -> &Self::Target {
359 match self {
360 TcpStream::Plain(plain, _) => plain,
361 #[cfg(feature = "native-tls")]
362 TcpStream::NativeTls(tls) => tls.get_ref(),
363 #[cfg(feature = "openssl")]
364 TcpStream::OpenSsl(tls) => tls.get_ref(),
365 #[cfg(feature = "rustls-common")]
366 TcpStream::Rustls(tls) => tls.get_ref(),
367 }
368 }
369}
370
371impl DerefMut for TcpStream {
372 fn deref_mut(&mut self) -> &mut Self::Target {
373 match self {
374 TcpStream::Plain(plain, _) => plain,
375 #[cfg(feature = "native-tls")]
376 TcpStream::NativeTls(tls) => tls.get_mut(),
377 #[cfg(feature = "openssl")]
378 TcpStream::OpenSsl(tls) => tls.get_mut(),
379 #[cfg(feature = "rustls-common")]
380 TcpStream::Rustls(tls) => tls.get_mut(),
381 }
382 }
383}
384
// Forwards `$self.$method($args...)` to whichever `TcpStream` variant `$self` is, so
// either the plain socket or the enabled TLS backend's stream handles the call.
// The comma after `$method` is always required, so a zero-argument call is written
// `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($self:ident, $method:ident, $($args:expr),*) => {
        match $self {
            TcpStream::Plain(plain, _) => plain.$method($($args),*),
            #[cfg(feature = "native-tls")]
            TcpStream::NativeTls(tls) => tls.$method($($args),*),
            #[cfg(feature = "openssl")]
            TcpStream::OpenSsl(tls) => tls.$method($($args),*),
            #[cfg(feature = "rustls-common")]
            TcpStream::Rustls(tls) => tls.$method($($args),*),
        }
    };
}
398
399impl Read for TcpStream {
400 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
401 fwd_impl!(self, read, buf)
402 }
403
404 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
405 fwd_impl!(self, read_vectored, bufs)
406 }
407
408 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
409 fwd_impl!(self, read_to_end, buf)
410 }
411
412 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
413 fwd_impl!(self, read_to_string, buf)
414 }
415
416 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
417 fwd_impl!(self, read_exact, buf)
418 }
419}
420
421impl Write for TcpStream {
422 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
423 fwd_impl!(self, write, buf)
424 }
425
426 fn flush(&mut self) -> io::Result<()> {
427 fwd_impl!(self, flush,)
428 }
429
430 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
431 fwd_impl!(self, write_vectored, bufs)
432 }
433
434 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
435 fwd_impl!(self, write_all, buf)
436 }
437
438 fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
439 fwd_impl!(self, write_fmt, fmt)
440 }
441}
442
443impl fmt::Debug for TcpStream {
444 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
445 f.debug_struct("TcpStream")
446 .field("inner", self.deref())
447 .finish()
448 }
449}
450
/// A stream whose connection or TLS handshake has not completed yet.
///
/// Obtained from [`HandshakeError::WouldBlock`]; resume it with
/// [`MidHandshakeTlsStream::handshake`].
// Backend mid-handshake states are stored inline (not boxed), so variant sizes can
// differ widely; the lint is silenced rather than boxing them.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum MidHandshakeTlsStream {
    /// A plain stream whose TCP connection is still being established.
    Plain(TcpStream),
    /// A native-tls handshake in progress.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsMidHandshakeTlsStream),
    /// An OpenSSL handshake in progress.
    #[cfg(feature = "openssl")]
    Openssl(OpenSslMidHandshakeTlsStream),
    /// A rustls handshake in progress.
    #[cfg(feature = "rustls-common")]
    Rustls(RustlsMidHandshakeTlsStream),
}
467
468impl MidHandshakeTlsStream {
469 #[must_use]
471 pub fn get_ref(&self) -> &TcpStream {
472 match self {
473 MidHandshakeTlsStream::Plain(mid) => mid,
474 #[cfg(feature = "native-tls")]
475 MidHandshakeTlsStream::NativeTls(mid) => mid.get_ref(),
476 #[cfg(feature = "openssl")]
477 MidHandshakeTlsStream::Openssl(mid) => mid.get_ref(),
478 #[cfg(feature = "rustls-common")]
479 MidHandshakeTlsStream::Rustls(mid) => mid.get_ref(),
480 }
481 }
482
483 #[must_use]
485 pub fn get_mut(&mut self) -> &mut TcpStream {
486 match self {
487 MidHandshakeTlsStream::Plain(mid) => mid,
488 #[cfg(feature = "native-tls")]
489 MidHandshakeTlsStream::NativeTls(mid) => mid.get_mut(),
490 #[cfg(feature = "openssl")]
491 MidHandshakeTlsStream::Openssl(mid) => mid.get_mut(),
492 #[cfg(feature = "rustls-common")]
493 MidHandshakeTlsStream::Rustls(mid) => mid.get_mut(),
494 }
495 }
496
497 pub fn handshake(self) -> HandshakeResult {
499 Ok(match self {
500 MidHandshakeTlsStream::Plain(mut mid) => {
501 if !mid.try_connect()? {
502 return Err(HandshakeError::WouldBlock(mid.into()));
503 }
504 mid
505 }
506 #[cfg(feature = "native-tls")]
507 MidHandshakeTlsStream::NativeTls(mut mid) => {
508 if !mid.get_mut().try_connect()? {
509 return Err(HandshakeError::WouldBlock(mid.into()));
510 }
511 mid.handshake()?.into()
512 }
513 #[cfg(feature = "openssl")]
514 MidHandshakeTlsStream::Openssl(mut mid) => {
515 if !mid.get_mut().try_connect()? {
516 return Err(HandshakeError::WouldBlock(mid.into()));
517 }
518 mid.handshake()?.into()
519 }
520 #[cfg(feature = "rustls-common")]
521 MidHandshakeTlsStream::Rustls(mut mid) => {
522 if !mid.get_mut().try_connect()? {
523 return Err(HandshakeError::WouldBlock(mid.into()));
524 }
525 mid.handshake()?.into()
526 }
527 })
528 }
529}
530
531impl From<TcpStream> for MidHandshakeTlsStream {
532 fn from(mid: TcpStream) -> Self {
533 MidHandshakeTlsStream::Plain(mid)
534 }
535}
536
537impl fmt::Display for MidHandshakeTlsStream {
538 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
539 f.write_str("MidHandshakeTlsStream")
540 }
541}
542
/// Error returned by the TLS upgrade methods and by
/// [`MidHandshakeTlsStream::handshake`].
// `WouldBlock` holds a whole `MidHandshakeTlsStream` inline, so the variants can
// differ greatly in size; the lint is silenced rather than boxing the payload.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum HandshakeError {
    /// The operation could not make progress without blocking. Resume it with
    /// [`MidHandshakeTlsStream::handshake`].
    WouldBlock(MidHandshakeTlsStream),
    /// The operation failed with an I/O error.
    Failure(io::Error),
}
552
553impl HandshakeError {
554 pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
556 match self {
557 Self::WouldBlock(mid) => Ok(mid),
558 Self::Failure(error) => Err(error),
559 }
560 }
561}
562
563impl fmt::Display for HandshakeError {
564 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
565 match self {
566 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
567 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
568 }
569 }
570}
571
572impl Error for HandshakeError {
573 fn source(&self) -> Option<&(dyn Error + 'static)> {
574 match self {
575 HandshakeError::Failure(err) => Some(err),
576 _ => None,
577 }
578 }
579}
580
581impl From<io::Error> for HandshakeError {
582 fn from(err: io::Error) -> Self {
583 HandshakeError::Failure(err)
584 }
585}
586
587mod sys;