1#![deny(missing_docs)]
2#![allow(clippy::result_large_err)]
3
4use cfg_if::cfg_if;
50use std::{
51 convert::TryFrom,
52 error::Error,
53 fmt,
54 io::{self, IoSlice, IoSliceMut, Read, Write},
55 net::{TcpStream as StdTcpStream, ToSocketAddrs},
56 ops::{Deref, DerefMut},
57 time::Duration,
58};
59
60#[cfg(feature = "rustls-common")]
61mod rustls_impl;
62#[cfg(feature = "rustls-common")]
63pub use rustls_impl::*;
64
65#[cfg(feature = "native-tls")]
66mod native_tls_impl;
67#[cfg(feature = "native-tls")]
68pub use native_tls_impl::*;
69
70#[cfg(feature = "openssl")]
71mod openssl_impl;
72#[cfg(feature = "openssl")]
73pub use openssl_impl::*;
74
75#[cfg(feature = "futures")]
76mod futures;
77#[cfg(feature = "futures")]
78pub use futures::*;
79
/// A TCP stream that is either plain or wrapped by one of the TLS backends.
///
/// Which TLS variants exist depends on the enabled crate features
/// (`native-tls`, `openssl`, `rustls-common`).
pub enum TcpStream {
    /// An unencrypted stream.
    ///
    /// The `bool` records whether the connection is known to be established;
    /// it starts as `false` when wrapping a std stream and is flipped to `true`
    /// by [`TcpStream::try_connect`].
    Plain(StdTcpStream, bool),
    /// A stream secured by the `native-tls` backend.
    #[cfg(feature = "native-tls")]
    NativeTls(Box<NativeTlsStream>),
    /// A stream secured by the `openssl` backend.
    #[cfg(feature = "openssl")]
    OpenSsl(Box<OpenSslStream>),
    /// A stream secured by the `rustls` backend.
    #[cfg(feature = "rustls-common")]
    Rustls(Box<RustlsStream>),
}
94
95#[derive(Default, Debug, PartialEq)]
97pub struct TLSConfig<'data, 'key, 'chain> {
98 pub identity: Option<Identity<'data, 'key>>,
100 pub cert_chain: Option<&'chain str>,
102}
103
104#[derive(Default, Debug, PartialEq)]
106pub struct OwnedTLSConfig {
107 pub identity: Option<OwnedIdentity>,
109 pub cert_chain: Option<String>,
111}
112
113impl OwnedTLSConfig {
114 #[must_use]
116 pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
117 TLSConfig {
118 identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
119 cert_chain: self.cert_chain.as_deref(),
120 }
121 }
122}
123
/// Certificate and key material for a TLS identity, borrowed from the caller.
///
/// `'data` bounds the certificate/bundle bytes and `'key` bounds the
/// password or private key.
#[derive(PartialEq, Debug)]
pub enum Identity<'data, 'key> {
    /// A PKCS#12 bundle together with the password that opens it.
    PKCS12 {
        /// Bundle bytes (DER-encoded, as the field name indicates).
        der: &'data [u8],
        /// Password for the bundle.
        password: &'key str,
    },
    /// A certificate and a private key supplied separately.
    PKCS8 {
        /// Certificate bytes, presumably PEM-encoded per the field name — confirm per backend.
        pem: &'data [u8],
        /// Private key bytes; expected encoding depends on the backend — TODO confirm.
        key: &'key [u8],
    },
}
144
/// Owned counterpart of [`Identity`].
///
/// Borrow it back as an [`Identity`] with [`OwnedIdentity::as_ref`].
#[derive(PartialEq, Debug)]
pub enum OwnedIdentity {
    /// A PKCS#12 bundle together with the password that opens it.
    PKCS12 {
        /// Bundle bytes (DER-encoded, as the field name indicates).
        der: Vec<u8>,
        /// Password for the bundle.
        password: String,
    },
    /// A certificate and a private key supplied separately.
    PKCS8 {
        /// Certificate bytes, presumably PEM-encoded per the field name — confirm per backend.
        pem: Vec<u8>,
        /// Private key bytes; expected encoding depends on the backend — TODO confirm.
        key: Vec<u8>,
    },
}
165
166impl OwnedIdentity {
167 #[must_use]
169 pub fn as_ref(&self) -> Identity<'_, '_> {
170 match self {
171 Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
172 Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
173 }
174 }
175}
176
177pub type HandshakeResult = Result<TcpStream, HandshakeError>;
179
180impl TcpStream {
181 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
183 connect_std(addr, None).and_then(Self::try_from)
184 }
185
186 pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
188 connect_std(addr, Some(timeout)).and_then(Self::try_from)
189 }
190
191 pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
193 Self::try_from(stream)
194 }
195
196 #[must_use]
198 #[allow(irrefutable_let_patterns)]
199 pub fn is_connected(&self) -> bool {
200 if let Self::Plain(_, connected) = self {
201 *connected
202 } else {
203 true
204 }
205 }
206
207 pub fn is_readable(&self) -> io::Result<()> {
209 self.deref().read(&mut []).map(|_| ())
210 }
211
212 pub fn is_writable(&self) -> io::Result<()> {
214 self.deref().write(&[]).map(|_| ())
215 }
216
217 #[allow(irrefutable_let_patterns)]
222 pub fn try_connect(&mut self) -> io::Result<bool> {
223 if self.is_connected() {
224 return Ok(true);
225 }
226 match self.is_writable() {
227 Ok(()) => {
228 if let Self::Plain(_, connected) = self {
229 *connected = true;
230 }
231 Ok(true)
232 }
233 Err(err)
234 if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected]
235 .contains(&err.kind()) =>
236 {
237 Ok(false)
238 }
239 Err(err) => Err(err),
240 }
241 }
242
243 pub fn into_tls(
245 self,
246 domain: &str,
247 config: TLSConfig<'_, '_, '_>,
248 ) -> Result<Self, HandshakeError> {
249 into_tls_impl(self, domain, config)
250 }
251
252 #[cfg(feature = "native-tls")]
253 pub fn into_native_tls(
255 self,
256 connector: &NativeTlsConnector,
257 domain: &str,
258 ) -> Result<Self, HandshakeError> {
259 Ok(connector.connect(domain, self.into_plain()?)?.into())
260 }
261
262 #[cfg(feature = "openssl")]
263 pub fn into_openssl(
265 self,
266 connector: &OpenSslConnector,
267 domain: &str,
268 ) -> Result<Self, HandshakeError> {
269 Ok(connector.connect(domain, self.into_plain()?)?.into())
270 }
271
272 #[cfg(feature = "rustls-common")]
273 pub fn into_rustls(
275 self,
276 connector: &RustlsConnector,
277 domain: &str,
278 ) -> Result<Self, HandshakeError> {
279 Ok(connector.connect(domain, self.into_plain()?)?.into())
280 }
281
282 #[allow(irrefutable_let_patterns)]
283 fn into_plain(self) -> Result<TcpStream, io::Error> {
284 if let TcpStream::Plain(plain, connected) = self {
285 Ok(TcpStream::Plain(plain, connected))
286 } else {
287 Err(io::Error::new(
288 io::ErrorKind::AlreadyExists,
289 "already a TLS stream",
290 ))
291 }
292 }
293}
294
295fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
296 let stream = connect_std_raw(addr, timeout)?;
297 stream.set_nodelay(true)?;
298 Ok(stream)
299}
300
/// Opens a TCP connection to `addr` without any socket tuning.
///
/// Without `timeout`, this defers to [`StdTcpStream::connect`]. With a
/// timeout, every address `addr` resolves to is tried in order using
/// [`StdTcpStream::connect_timeout`], and the first success is returned.
///
/// # Errors
///
/// On the timeout path, returns the error from the last address tried. If
/// resolution yields no addresses, it returns an `AddrNotAvailable` error
/// ("couldn't resolve host").
fn connect_std_raw<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> io::Result<StdTcpStream> {
    let limit = match timeout {
        Some(limit) => limit,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_error = None;
    for candidate in addr.to_socket_addrs()? {
        match StdTcpStream::connect_timeout(&candidate, limit) {
            Ok(stream) => return Ok(stream),
            Err(failure) => last_error = Some(failure),
        }
    }

    Err(last_error.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
    }))
}
321
322cfg_if! {
323 if #[cfg(feature = "rustls-native-certs")] {
324 fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
325 into_rustls_impl(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
326 }
327 } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
328 fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
329 into_rustls_impl(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
330 }
331 } else if #[cfg(feature = "rustls-common")] {
332 fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
333 into_rustls_impl(s, RustlsConnectorConfig::default(), domain, config)
334 }
335 } else if #[cfg(feature = "openssl")] {
336 fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
337 into_openssl_impl(s, domain, config)
338 }
339 } else if #[cfg(feature = "native-tls")] {
340 fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
341 into_native_tls_impl(s, domain, config)
342 }
343 } else {
344 fn into_tls_impl(s: TcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> HandshakeResult {
345 Ok(s.into_plain()?)
346 }
347 }
348}
349
350impl TryFrom<StdTcpStream> for TcpStream {
351 type Error = io::Error;
352
353 fn try_from(s: StdTcpStream) -> io::Result<Self> {
354 let mut this = TcpStream::Plain(s, false);
355 this.try_connect()?;
356 Ok(this)
357 }
358}
359
360impl Deref for TcpStream {
361 type Target = StdTcpStream;
362
363 fn deref(&self) -> &Self::Target {
364 match self {
365 TcpStream::Plain(plain, _) => plain,
366 #[cfg(feature = "native-tls")]
367 TcpStream::NativeTls(tls) => tls.get_ref(),
368 #[cfg(feature = "openssl")]
369 TcpStream::OpenSsl(tls) => tls.get_ref(),
370 #[cfg(feature = "rustls-common")]
371 TcpStream::Rustls(tls) => tls.get_ref(),
372 }
373 }
374}
375
376impl DerefMut for TcpStream {
377 fn deref_mut(&mut self) -> &mut Self::Target {
378 match self {
379 TcpStream::Plain(plain, _) => plain,
380 #[cfg(feature = "native-tls")]
381 TcpStream::NativeTls(tls) => tls.get_mut(),
382 #[cfg(feature = "openssl")]
383 TcpStream::OpenSsl(tls) => tls.get_mut(),
384 #[cfg(feature = "rustls-common")]
385 TcpStream::Rustls(tls) => tls.get_mut(),
386 }
387 }
388}
389
// Forwards a method call to whichever concrete stream a `TcpStream` wraps
// (the TLS stream object for TLS variants, the std socket for `Plain`).
//
// Usage: `fwd_impl!(stream_ident, method, arg1, arg2)`. A method with no
// arguments still needs the comma after its name, e.g. `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($stream:ident, $func:ident, $($arg:expr),*) => {
        match $stream {
            TcpStream::Plain(inner, _) => inner.$func($($arg),*),
            #[cfg(feature = "native-tls")]
            TcpStream::NativeTls(inner) => inner.$func($($arg),*),
            #[cfg(feature = "openssl")]
            TcpStream::OpenSsl(inner) => inner.$func($($arg),*),
            #[cfg(feature = "rustls-common")]
            TcpStream::Rustls(inner) => inner.$func($($arg),*),
        }
    };
}
403
404impl Read for TcpStream {
405 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
406 fwd_impl!(self, read, buf)
407 }
408
409 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
410 fwd_impl!(self, read_vectored, bufs)
411 }
412
413 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
414 fwd_impl!(self, read_to_end, buf)
415 }
416
417 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
418 fwd_impl!(self, read_to_string, buf)
419 }
420
421 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
422 fwd_impl!(self, read_exact, buf)
423 }
424}
425
426impl Write for TcpStream {
427 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
428 fwd_impl!(self, write, buf)
429 }
430
431 fn flush(&mut self) -> io::Result<()> {
432 fwd_impl!(self, flush,)
433 }
434
435 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
436 fwd_impl!(self, write_vectored, bufs)
437 }
438
439 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
440 fwd_impl!(self, write_all, buf)
441 }
442
443 fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
444 fwd_impl!(self, write_fmt, fmt)
445 }
446}
447
448impl fmt::Debug for TcpStream {
449 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
450 f.debug_struct("TcpStream")
451 .field("inner", self.deref())
452 .finish()
453 }
454}
455
456#[allow(clippy::large_enum_variant)]
458#[derive(Debug)]
459pub enum MidHandshakeTlsStream {
460 Plain(TcpStream),
462 #[cfg(feature = "native-tls")]
463 NativeTls(NativeTlsMidHandshakeTlsStream),
465 #[cfg(feature = "openssl")]
466 Openssl(OpenSslMidHandshakeTlsStream),
468 #[cfg(feature = "rustls-common")]
469 Rustls(RustlsMidHandshakeTlsStream),
471}
472
473impl MidHandshakeTlsStream {
474 #[must_use]
476 pub fn get_ref(&self) -> &TcpStream {
477 match self {
478 MidHandshakeTlsStream::Plain(mid) => mid,
479 #[cfg(feature = "native-tls")]
480 MidHandshakeTlsStream::NativeTls(mid) => mid.get_ref(),
481 #[cfg(feature = "openssl")]
482 MidHandshakeTlsStream::Openssl(mid) => mid.get_ref(),
483 #[cfg(feature = "rustls-common")]
484 MidHandshakeTlsStream::Rustls(mid) => mid.get_ref(),
485 }
486 }
487
488 #[must_use]
490 pub fn get_mut(&mut self) -> &mut TcpStream {
491 match self {
492 MidHandshakeTlsStream::Plain(mid) => mid,
493 #[cfg(feature = "native-tls")]
494 MidHandshakeTlsStream::NativeTls(mid) => mid.get_mut(),
495 #[cfg(feature = "openssl")]
496 MidHandshakeTlsStream::Openssl(mid) => mid.get_mut(),
497 #[cfg(feature = "rustls-common")]
498 MidHandshakeTlsStream::Rustls(mid) => mid.get_mut(),
499 }
500 }
501
502 pub fn handshake(self) -> HandshakeResult {
504 Ok(match self {
505 MidHandshakeTlsStream::Plain(mut mid) => {
506 if !mid.try_connect()? {
507 return Err(HandshakeError::WouldBlock(mid.into()));
508 }
509 mid
510 }
511 #[cfg(feature = "native-tls")]
512 MidHandshakeTlsStream::NativeTls(mut mid) => {
513 if !mid.get_mut().try_connect()? {
514 return Err(HandshakeError::WouldBlock(mid.into()));
515 }
516 mid.handshake()?.into()
517 }
518 #[cfg(feature = "openssl")]
519 MidHandshakeTlsStream::Openssl(mut mid) => {
520 if !mid.get_mut().try_connect()? {
521 return Err(HandshakeError::WouldBlock(mid.into()));
522 }
523 mid.handshake()?.into()
524 }
525 #[cfg(feature = "rustls-common")]
526 MidHandshakeTlsStream::Rustls(mut mid) => {
527 if !mid.get_mut().try_connect()? {
528 return Err(HandshakeError::WouldBlock(mid.into()));
529 }
530 mid.handshake()?.into()
531 }
532 })
533 }
534}
535
536impl From<TcpStream> for MidHandshakeTlsStream {
537 fn from(mid: TcpStream) -> Self {
538 MidHandshakeTlsStream::Plain(mid)
539 }
540}
541
542impl fmt::Display for MidHandshakeTlsStream {
543 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
544 f.write_str("MidHandshakeTlsStream")
545 }
546}
547
548#[allow(clippy::large_enum_variant)]
550#[derive(Debug)]
551pub enum HandshakeError {
552 WouldBlock(MidHandshakeTlsStream),
554 Failure(io::Error),
556}
557
558impl HandshakeError {
559 pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
561 match self {
562 Self::WouldBlock(mid) => Ok(mid),
563 Self::Failure(error) => Err(error),
564 }
565 }
566}
567
568impl fmt::Display for HandshakeError {
569 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
570 match self {
571 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
572 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
573 }
574 }
575}
576
577impl Error for HandshakeError {
578 fn source(&self) -> Option<&(dyn Error + 'static)> {
579 match self {
580 HandshakeError::Failure(err) => Some(err),
581 _ => None,
582 }
583 }
584}
585
586impl From<io::Error> for HandshakeError {
587 fn from(err: io::Error) -> Self {
588 HandshakeError::Failure(err)
589 }
590}
591
592mod sys;