1#![deny(missing_docs)]
2#![allow(clippy::large_enum_variant, clippy::result_large_err)]
3
4use cfg_if::cfg_if;
50use std::{
51 convert::TryFrom,
52 error::Error,
53 fmt,
54 io::{self, IoSlice, IoSliceMut, Read, Write},
55 net::{TcpStream as StdTcpStream, ToSocketAddrs},
56 ops::{Deref, DerefMut},
57 time::Duration,
58};
59
60#[cfg(feature = "rustls")]
61mod rustls_impl;
62#[cfg(feature = "rustls")]
63pub use rustls_impl::*;
64
65#[cfg(feature = "native-tls")]
66mod native_tls_impl;
67#[cfg(feature = "native-tls")]
68pub use native_tls_impl::*;
69
70#[cfg(feature = "openssl")]
71mod openssl_impl;
72#[cfg(feature = "openssl")]
73pub use openssl_impl::*;
74
75#[cfg(feature = "futures")]
76mod futures;
77#[cfg(feature = "futures")]
78pub use futures::*;
79
/// A TCP stream that is either plain or wrapped by one of the optional TLS backends.
///
/// Each TLS variant only exists when the cargo feature of the same name is enabled.
/// The type dereferences to the underlying standard-library socket whatever the variant.
#[non_exhaustive]
pub enum TcpStream {
    /// An unencrypted standard-library stream.
    Plain(StdTcpStream),
    /// A stream wrapped by the `native-tls` backend.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsStream),
    /// A stream wrapped by the `openssl` backend.
    #[cfg(feature = "openssl")]
    Openssl(OpensslStream),
    /// A stream wrapped by the `rustls` backend.
    #[cfg(feature = "rustls")]
    Rustls(RustlsStream),
}
95
96#[derive(Default, Debug, PartialEq)]
98pub struct TLSConfig<'data, 'key, 'chain> {
99 pub identity: Option<Identity<'data, 'key>>,
101 pub cert_chain: Option<&'chain str>,
103}
104
105#[derive(Default, Debug, PartialEq)]
107pub struct OwnedTLSConfig {
108 pub identity: Option<OwnedIdentity>,
110 pub cert_chain: Option<String>,
112}
113
114impl OwnedTLSConfig {
115 #[must_use]
117 pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
118 TLSConfig {
119 identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
120 cert_chain: self.cert_chain.as_deref(),
121 }
122 }
123}
124
/// Borrowed identity material, in one of two container formats.
///
/// The `'data` lifetime covers the certificate bytes and `'key` covers the secret.
#[derive(Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// A PKCS#12 bundle.
    PKCS12 {
        /// Raw bytes of the bundle.
        // NOTE(review): the field name suggests DER encoding; confirm in the backend modules.
        der: &'data [u8],
        /// Password associated with the bundle.
        password: &'key str,
    },
    /// A certificate together with a separate PKCS#8 key.
    PKCS8 {
        /// Raw bytes of the certificate.
        // NOTE(review): the field name suggests PEM encoding; confirm in the backend modules.
        pem: &'data [u8],
        /// Raw bytes of the key.
        key: &'key [u8],
    },
}
145
/// Owned identity material, in one of two container formats.
///
/// This is the owning counterpart of [`Identity`]; see [`OwnedIdentity::as_ref`].
#[derive(Debug, PartialEq)]
pub enum OwnedIdentity {
    /// A PKCS#12 bundle.
    PKCS12 {
        /// Raw bytes of the bundle.
        // NOTE(review): the field name suggests DER encoding; confirm in the backend modules.
        der: Vec<u8>,
        /// Password associated with the bundle.
        password: String,
    },
    /// A certificate together with a separate PKCS#8 key.
    PKCS8 {
        /// Raw bytes of the certificate.
        // NOTE(review): the field name suggests PEM encoding; confirm in the backend modules.
        pem: Vec<u8>,
        /// Raw bytes of the key.
        key: Vec<u8>,
    },
}
166
167impl OwnedIdentity {
168 #[must_use]
170 pub fn as_ref(&self) -> Identity<'_, '_> {
171 match self {
172 Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
173 Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
174 }
175 }
176}
177
178pub type HandshakeResult = Result<TcpStream, HandshakeError>;
180
181impl TcpStream {
182 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
184 connect_std(addr, None).and_then(Self::try_from)
185 }
186
187 pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
189 connect_std(addr, Some(timeout)).and_then(Self::try_from)
190 }
191
192 pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
194 Self::try_from(stream)
195 }
196
197 pub fn is_readable(&self) -> io::Result<()> {
199 self.deref().read(&mut []).map(|_| ())
200 }
201
202 pub fn is_writable(&mut self) -> io::Result<()> {
204 is_writable(self.deref_mut())
205 }
206
207 pub fn try_connect(&mut self) -> io::Result<bool> {
212 try_connect(self)
213 }
214
215 pub fn into_tls(
217 self,
218 domain: &str,
219 config: TLSConfig<'_, '_, '_>,
220 ) -> Result<Self, HandshakeError> {
221 into_tls_impl(self, domain, config)
222 }
223
224 #[cfg(feature = "native-tls")]
225 pub fn into_native_tls(
227 self,
228 connector: &NativeTlsConnector,
229 domain: &str,
230 ) -> Result<Self, HandshakeError> {
231 Ok(connector.connect(domain, self.into_plain()?)?.into())
232 }
233
234 #[cfg(feature = "openssl")]
235 pub fn into_openssl(
237 self,
238 connector: &OpensslConnector,
239 domain: &str,
240 ) -> Result<Self, HandshakeError> {
241 Ok(connector.connect(domain, self.into_plain()?)?.into())
242 }
243
244 #[cfg(feature = "rustls")]
245 pub fn into_rustls(
247 self,
248 connector: &RustlsConnector,
249 domain: &str,
250 ) -> Result<Self, HandshakeError> {
251 Ok(connector.connect(domain, self.into_plain()?)?.into())
252 }
253
254 #[allow(irrefutable_let_patterns)]
255 fn into_plain(self) -> Result<StdTcpStream, io::Error> {
256 if let Self::Plain(plain) = self {
257 Ok(plain)
258 } else {
259 Err(io::Error::new(
260 io::ErrorKind::AlreadyExists,
261 "already a TLS stream",
262 ))
263 }
264 }
265}
266
267fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
268 let stream = connect_std_raw(addr, timeout)?;
269 stream.set_nodelay(true)?;
270 Ok(stream)
271}
272
/// Opens a standard-library TCP connection to `addr`.
///
/// Without a timeout this is a plain `TcpStream::connect`. With a timeout, `addr` is resolved
/// and every resulting address is tried in order with `connect_timeout`; the first success
/// is returned. If all attempts fail the last error is returned, and if resolution yields no
/// address at all an `AddrNotAvailable` error is reported.
fn connect_std_raw<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> io::Result<StdTcpStream> {
    match timeout {
        None => StdTcpStream::connect(addr),
        Some(limit) => {
            let mut last_failure = None;
            for candidate in addr.to_socket_addrs()? {
                match StdTcpStream::connect_timeout(&candidate, limit) {
                    Ok(stream) => return Ok(stream),
                    // Remember the failure and move on to the next address.
                    Err(failure) => last_failure = Some(failure),
                }
            }
            Err(last_failure.unwrap_or_else(|| {
                io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
            }))
        }
    }
}
293
294fn try_connect(stream: &mut StdTcpStream) -> io::Result<bool> {
295 match is_writable(stream) {
296 Ok(()) => Ok(true),
297 Err(err)
298 if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected].contains(&err.kind()) =>
299 {
300 Ok(false)
301 }
302 Err(err) => Err(err),
303 }
304}
305
/// Issues a zero-length write on `stream`, discarding the byte count and returning any error.
fn is_writable(stream: &mut StdTcpStream) -> io::Result<()> {
    stream.write(&[])?;
    Ok(())
}
309
310fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
311 cfg_if! {
312 if #[cfg(feature = "rustls-native-certs")] {
313 into_rustls_impl(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
314 } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
315 into_rustls_impl(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
316 } else if #[cfg(feature = "rustls")] {
317 into_rustls_impl(s, RustlsConnectorConfig::default(), domain, config)
318 } else if #[cfg(feature = "openssl")] {
319 into_openssl_impl(s, domain, config)
320 } else if #[cfg(feature = "native-tls")] {
321 into_native_tls_impl(s, domain, config)
322 } else {
323 let _ = (domain, config);
324 Ok(TcpStream::Plain(s.into_plain()?))
325 }
326 }
327}
328
329impl TryFrom<StdTcpStream> for TcpStream {
330 type Error = io::Error;
331
332 fn try_from(s: StdTcpStream) -> io::Result<Self> {
333 let mut this = Self::Plain(s);
334 this.try_connect()?;
335 Ok(this)
336 }
337}
338
339impl Deref for TcpStream {
340 type Target = StdTcpStream;
341
342 fn deref(&self) -> &Self::Target {
343 match self {
344 Self::Plain(plain) => plain,
345 #[cfg(feature = "native-tls")]
346 Self::NativeTls(tls) => tls.get_ref(),
347 #[cfg(feature = "openssl")]
348 Self::Openssl(tls) => tls.get_ref(),
349 #[cfg(feature = "rustls")]
350 Self::Rustls(tls) => tls.get_ref(),
351 }
352 }
353}
354
355impl DerefMut for TcpStream {
356 fn deref_mut(&mut self) -> &mut Self::Target {
357 match self {
358 Self::Plain(plain) => plain,
359 #[cfg(feature = "native-tls")]
360 Self::NativeTls(tls) => tls.get_mut(),
361 #[cfg(feature = "openssl")]
362 Self::Openssl(tls) => tls.get_mut(),
363 #[cfg(feature = "rustls")]
364 Self::Rustls(tls) => tls.get_mut(),
365 }
366 }
367}
368
// Forwards a method call to whichever stream the current variant holds.
//
// Usage: `fwd_impl!(self, method, arg1, arg2)`; for a method without arguments keep the
// trailing comma, as in `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($this:ident, $name:ident, $($arg:expr),*) => {
        match $this {
            Self::Plain(stream) => stream.$name($($arg),*),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(stream) => stream.$name($($arg),*),
            #[cfg(feature = "openssl")]
            Self::Openssl(stream) => stream.$name($($arg),*),
            #[cfg(feature = "rustls")]
            Self::Rustls(stream) => stream.$name($($arg),*),
        }
    };
}
382
383impl Read for TcpStream {
384 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
385 fwd_impl!(self, read, buf)
386 }
387
388 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
389 fwd_impl!(self, read_vectored, bufs)
390 }
391
392 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
393 fwd_impl!(self, read_to_end, buf)
394 }
395
396 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
397 fwd_impl!(self, read_to_string, buf)
398 }
399
400 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
401 fwd_impl!(self, read_exact, buf)
402 }
403}
404
405impl Write for TcpStream {
406 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
407 fwd_impl!(self, write, buf)
408 }
409
410 fn flush(&mut self) -> io::Result<()> {
411 fwd_impl!(self, flush,)
412 }
413
414 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
415 fwd_impl!(self, write_vectored, bufs)
416 }
417
418 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
419 fwd_impl!(self, write_all, buf)
420 }
421
422 fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
423 fwd_impl!(self, write_fmt, fmt)
424 }
425}
426
427impl fmt::Debug for TcpStream {
428 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
429 f.debug_struct("TcpStream")
430 .field("inner", self.deref())
431 .finish()
432 }
433}
434
435#[derive(Debug)]
437pub enum MidHandshakeTlsStream {
438 Plain(TcpStream),
440 #[cfg(feature = "native-tls")]
441 NativeTls(NativeTlsMidHandshakeTlsStream),
443 #[cfg(feature = "openssl")]
444 Openssl(OpensslMidHandshakeTlsStream),
446 #[cfg(feature = "rustls")]
447 Rustls(RustlsMidHandshakeTlsStream),
449}
450
451impl MidHandshakeTlsStream {
452 #[must_use]
454 pub fn get_ref(&self) -> &StdTcpStream {
455 match self {
456 Self::Plain(mid) => mid,
457 #[cfg(feature = "native-tls")]
458 Self::NativeTls(mid) => mid.get_ref(),
459 #[cfg(feature = "openssl")]
460 Self::Openssl(mid) => mid.get_ref(),
461 #[cfg(feature = "rustls")]
462 Self::Rustls(mid) => mid.get_ref(),
463 }
464 }
465
466 #[must_use]
468 pub fn get_mut(&mut self) -> &mut StdTcpStream {
469 match self {
470 Self::Plain(mid) => mid,
471 #[cfg(feature = "native-tls")]
472 Self::NativeTls(mid) => mid.get_mut(),
473 #[cfg(feature = "openssl")]
474 Self::Openssl(mid) => mid.get_mut(),
475 #[cfg(feature = "rustls")]
476 Self::Rustls(mid) => mid.get_mut(),
477 }
478 }
479
480 pub fn handshake(mut self) -> HandshakeResult {
482 if !try_connect(self.get_mut())? {
483 return Err(HandshakeError::WouldBlock(self));
484 }
485
486 Ok(match self {
487 Self::Plain(mid) => mid,
488 #[cfg(feature = "native-tls")]
489 Self::NativeTls(mid) => mid.handshake()?.into(),
490 #[cfg(feature = "openssl")]
491 Self::Openssl(mid) => mid.handshake()?.into(),
492 #[cfg(feature = "rustls")]
493 Self::Rustls(mid) => mid.handshake()?.into(),
494 })
495 }
496}
497
498impl From<TcpStream> for MidHandshakeTlsStream {
499 fn from(mid: TcpStream) -> Self {
500 Self::Plain(mid)
501 }
502}
503
504impl fmt::Display for MidHandshakeTlsStream {
505 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
506 f.write_str("MidHandshakeTlsStream")
507 }
508}
509
510#[derive(Debug)]
512pub enum HandshakeError {
513 WouldBlock(MidHandshakeTlsStream),
515 Failure(io::Error),
517}
518
519impl HandshakeError {
520 pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
522 match self {
523 Self::WouldBlock(mid) => Ok(mid),
524 Self::Failure(error) => Err(error),
525 }
526 }
527}
528
529impl fmt::Display for HandshakeError {
530 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
531 match self {
532 Self::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
533 Self::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
534 }
535 }
536}
537
538impl Error for HandshakeError {
539 fn source(&self) -> Option<&(dyn Error + 'static)> {
540 match self {
541 Self::Failure(err) => Some(err),
542 _ => None,
543 }
544 }
545}
546
547impl From<io::Error> for HandshakeError {
548 fn from(err: io::Error) -> Self {
549 Self::Failure(err)
550 }
551}
552
553mod sys;