#![deny(missing_docs)]
#![warn(rust_2018_idioms)]
#![doc(html_root_url = "https://docs.rs/tcp-stream/0.10.1/")]
use cfg_if::cfg_if;
use mio::{Interest, Registry, Token, event::Source, net::TcpStream as MioTcpStream};
use std::{
error::Error,
fmt,
io::{self, Read, Write},
net::{self, ToSocketAddrs},
ops::{Deref, DerefMut},
};
/// Re-export of `native_tls::TlsConnector`; passed to `TcpStream::into_native_tls`.
#[cfg(feature = "native-tls")]
pub use native_tls::TlsConnector as NativeTlsConnector;
/// A native-tls encrypted stream running over a `mio` TCP stream.
#[cfg(feature = "native-tls")]
pub type NativeTlsStream = native_tls::TlsStream<MioTcpStream>;
/// A native-tls handshake over a `mio` TCP stream that stopped before completing.
#[cfg(feature = "native-tls")]
pub type NativeTlsMidHandshakeTlsStream = native_tls::MidHandshakeTlsStream<MioTcpStream>;
/// The error type of a native-tls handshake over a `mio` TCP stream.
#[cfg(feature = "native-tls")]
pub type NativeTlsHandshakeError = native_tls::HandshakeError<MioTcpStream>;
/// Re-exports of OpenSSL's connector and protocol-method types; the connector is passed
/// to `TcpStream::into_openssl`.
#[cfg(feature = "openssl")]
pub use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod as OpenSslMethod};
/// An OpenSSL encrypted stream running over a `mio` TCP stream.
#[cfg(feature = "openssl")]
pub type OpenSslStream = openssl::ssl::SslStream<MioTcpStream>;
/// An OpenSSL handshake over a `mio` TCP stream that stopped before completing.
#[cfg(feature = "openssl")]
pub type OpenSslMidHandshakeTlsStream = openssl::ssl::MidHandshakeSslStream<MioTcpStream>;
/// The error type of an OpenSSL handshake over a `mio` TCP stream.
#[cfg(feature = "openssl")]
pub type OpenSslHandshakeError = openssl::ssl::HandshakeError<MioTcpStream>;
/// Alias for OpenSSL's error stack type.
#[cfg(feature = "openssl")]
pub type OpenSslErrorStack = openssl::error::ErrorStack;
/// Re-export of `rustls_connector::RustlsConnector`; passed to `TcpStream::into_rustls`.
#[cfg(feature = "rustls-connector")]
pub use rustls_connector::RustlsConnector;
/// A rustls encrypted stream running over a `mio` TCP stream.
#[cfg(feature = "rustls-connector")]
pub type RustlsStream = rustls_connector::TlsStream<MioTcpStream>;
/// A rustls handshake over a `mio` TCP stream that stopped before completing.
#[cfg(feature = "rustls-connector")]
pub type RustlsMidHandshakeTlsStream = rustls_connector::MidHandshakeTlsStream<MioTcpStream>;
/// The error type of a rustls handshake over a `mio` TCP stream.
#[cfg(feature = "rustls-connector")]
pub type RustlsHandshakeError = rustls_connector::HandshakeError<MioTcpStream>;
/// A TCP stream that is either plain or wrapped in one of the TLS backends enabled
/// through cargo features (`native-tls`, `openssl`, `rustls-connector`).
///
/// Every variant sits on top of a `mio` TCP stream. That stream is reachable through
/// `Deref`/`DerefMut` and is what gets registered with a `mio` `Registry`.
// The TLS variants hold their stream types by value and may be much larger than
// `Plain`. The size lint is silenced rather than boxing the public variant payloads.
#[allow(clippy::large_enum_variant)]
pub enum TcpStream {
    /// An unencrypted `mio` TCP stream.
    Plain(MioTcpStream),
    /// A stream encrypted with native-tls.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsStream),
    /// A stream encrypted with OpenSSL.
    #[cfg(feature = "openssl")]
    OpenSsl(OpenSslStream),
    /// A stream encrypted with rustls.
    #[cfg(feature = "rustls-connector")]
    Rustls(RustlsStream),
}
/// A TLS client identity: a PKCS#12 archive in DER form plus the password protecting it.
///
/// Used by `TcpStream::into_tls` to present a client certificate with the OpenSSL and
/// native-tls backends. The rustls backend ignores it.
///
/// `Debug` is implemented by hand so that the password never ends up in logs or panic
/// messages; it is rendered as `<redacted>`.
#[derive(PartialEq)]
pub struct Identity<'a, 'b> {
    /// The DER-encoded PKCS#12 archive holding the certificate and private key.
    pub der: &'a [u8],
    /// The password used to decrypt the PKCS#12 archive.
    pub password: &'b str,
}

impl fmt::Debug for Identity<'_, '_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A derived `Debug` would print the secret in clear text.
        f.debug_struct("Identity")
            .field("der", &self.der)
            .field("password", &"<redacted>")
            .finish()
    }
}
impl TcpStream {
    /// Opens a plain (unencrypted) TCP connection to `addr`.
    ///
    /// The socket addresses that `addr` resolves to are tried in order. The first
    /// successful `mio` connect is returned as `TcpStream::Plain`.
    ///
    /// # Errors
    ///
    /// - The resolution error, if `addr` cannot be resolved.
    /// - The error from the last attempt, if every address fails.
    /// - An `AddrNotAvailable` error ("couldn't resolve host"), if `addr` resolves to no
    ///   addresses at all.
    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
        // NOTE(review): mio's connect is non-blocking, so `Ok` may only mean the
        // connection attempt has started. Confirm against the mio documentation.
        let mut last_failure = None;
        for candidate in addr.to_socket_addrs()? {
            let failure = match MioTcpStream::connect(candidate) {
                Ok(connected) => return Ok(Self::from(connected)),
                Err(failure) => failure,
            };
            last_failure = Some(failure);
        }
        match last_failure {
            Some(failure) => Err(failure),
            None => Err(io::Error::new(
                io::ErrorKind::AddrNotAvailable,
                "couldn't resolve host",
            )),
        }
    }

    /// Wraps an existing standard-library TCP stream as a plain `TcpStream`.
    pub fn from_std(stream: net::TcpStream) -> Self {
        // NOTE(review): this does not change the socket's blocking mode. mio presumably
        // expects the stream to be non-blocking already; confirm callers set it.
        Self::Plain(MioTcpStream::from_std(stream))
    }

    /// Upgrades this plain stream to TLS using the backend selected by the enabled features.
    ///
    /// Precedence: rustls with native certificates, rustls with webpki roots, rustls with
    /// its default connector, then OpenSSL, then native-tls. Without any TLS feature, the
    /// stream is returned as `Plain`.
    ///
    /// `identity` is installed as the client certificate by the OpenSSL and native-tls
    /// backends. The rustls backends ignore it.
    ///
    /// # Errors
    ///
    /// - `HandshakeError::Failure` if the stream is already encrypted, connector setup
    ///   fails, or the handshake fails.
    /// - `HandshakeError::WouldBlock` if the handshake cannot complete without blocking.
    pub fn into_tls(self, domain: &str, identity: Option<Identity<'_, '_>>) -> Result<Self, HandshakeError> {
        into_tls_impl(self, domain, identity)
    }

    /// Performs a native-tls client handshake for `domain` over this plain stream.
    ///
    /// # Errors
    ///
    /// - `HandshakeError::Failure` if the stream is already encrypted or the handshake fails.
    /// - `HandshakeError::WouldBlock` if the handshake cannot complete without blocking.
    #[cfg(feature = "native-tls")]
    pub fn into_native_tls(
        self,
        connector: NativeTlsConnector,
        domain: &str,
    ) -> Result<Self, HandshakeError> {
        let plain = self.into_plain()?;
        let encrypted = connector.connect(domain, plain)?;
        Ok(Self::from(encrypted))
    }

    /// Performs an OpenSSL client handshake for `domain` over this plain stream.
    ///
    /// # Errors
    ///
    /// - `HandshakeError::Failure` if the stream is already encrypted or the handshake fails.
    /// - `HandshakeError::WouldBlock` if the handshake cannot complete without blocking.
    #[cfg(feature = "openssl")]
    pub fn into_openssl(
        self,
        connector: OpenSslConnector,
        domain: &str,
    ) -> Result<Self, HandshakeError> {
        let plain = self.into_plain()?;
        let encrypted = connector.connect(domain, plain)?;
        Ok(Self::from(encrypted))
    }

    /// Performs a rustls client handshake for `domain` over this plain stream.
    ///
    /// # Errors
    ///
    /// - `HandshakeError::Failure` if the stream is already encrypted or the handshake fails.
    /// - `HandshakeError::WouldBlock` if the handshake cannot complete without blocking.
    #[cfg(feature = "rustls-connector")]
    pub fn into_rustls(
        self,
        connector: RustlsConnector,
        domain: &str,
    ) -> Result<Self, HandshakeError> {
        let plain = self.into_plain()?;
        let encrypted = connector.connect(domain, plain)?;
        Ok(Self::from(encrypted))
    }

    // Extracts the raw mio stream, refusing streams that already carry a TLS session.
    // With no TLS feature enabled, `Plain` is the only variant, so the `if let` is
    // irrefutable; that is why the lint is allowed.
    #[allow(irrefutable_let_patterns)]
    fn into_plain(self) -> Result<MioTcpStream, io::Error> {
        if let Self::Plain(raw) = self {
            Ok(raw)
        } else {
            let kind = io::ErrorKind::AlreadyExists;
            Err(io::Error::new(kind, "already a TLS stream"))
        }
    }
}
/// Shared rustls path for `TcpStream::into_tls`: runs the handshake with `connector`.
///
/// The `Identity` argument is ignored, so no client certificate is presented on this path.
#[cfg(feature = "rustls-connector")]
fn into_rustls_common(
    stream: TcpStream,
    connector: RustlsConnector,
    domain: &str,
    _identity: Option<Identity<'_, '_>>,
) -> Result<TcpStream, HandshakeError> {
    TcpStream::into_rustls(stream, connector, domain)
}
// Selects, at compile time, the implementation behind `TcpStream::into_tls`.
// `cfg_if!` takes the first branch whose feature is enabled, so the precedence is:
// 1. rustls with native certificates
// 2. rustls with webpki roots
// 3. rustls with the default connector
// 4. OpenSSL
// 5. native-tls
// 6. a no-TLS fallback that keeps the stream plain
cfg_if! {
    if #[cfg(feature = "rustls-native-certs")] {
        // rustls connector built from native certificates. Construction is fallible,
        // hence the `?`.
        fn into_tls_impl(s: TcpStream, domain: &str, identity: Option<Identity<'_, '_>>) -> Result<TcpStream, HandshakeError> {
            into_rustls_common(s, RustlsConnector::new_with_native_certs()?, domain, identity)
        }
    } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
        // rustls connector built from webpki root certificates.
        fn into_tls_impl(s: TcpStream, domain: &str, identity: Option<Identity<'_, '_>>) -> Result<TcpStream, HandshakeError> {
            into_rustls_common(s, RustlsConnector::new_with_webpki_roots_certs(), domain, identity)
        }
    } else if #[cfg(feature = "rustls-connector")] {
        // rustls connector with its `Default` configuration.
        fn into_tls_impl(s: TcpStream, domain: &str, identity: Option<Identity<'_, '_>>) -> Result<TcpStream, HandshakeError> {
            into_rustls_common(s, RustlsConnector::default(), domain, identity)
        }
    } else if #[cfg(feature = "openssl")] {
        // OpenSSL connector. A supplied identity is parsed as a PKCS#12 archive and
        // installed as the client certificate, private key and extra chain certificates.
        fn into_tls_impl(s: TcpStream, domain: &str, identity: Option<Identity<'_, '_>>) -> Result<TcpStream, HandshakeError> {
            let mut builder = OpenSslConnector::builder(OpenSslMethod::tls())?;
            if let Some(identity) = identity {
                let identity = openssl::pkcs12::Pkcs12::from_der(identity.der)?.parse(identity.password)?;
                builder.set_certificate(&identity.cert)?;
                builder.set_private_key(&identity.pkey)?;
                if let Some(chain) = identity.chain.as_ref() {
                    // Chain certificates are added in the reverse of the order `parse` yields them.
                    // NOTE(review): presumably compensates for the PKCS#12 parser's ordering — confirm.
                    for cert in chain.iter().rev() {
                        builder.add_extra_chain_cert(cert.to_owned())?;
                    }
                }
            }
            s.into_openssl(builder.build(), domain)
        }
    } else if #[cfg(feature = "native-tls")] {
        // native-tls connector. A supplied identity is loaded with `Identity::from_pkcs12`.
        // native-tls errors are wrapped as `io::ErrorKind::Other` so `?` can convert them.
        fn into_tls_impl(s: TcpStream, domain: &str, identity: Option<Identity<'_, '_>>) -> Result<TcpStream, HandshakeError> {
            let mut builder = NativeTlsConnector::builder();
            if let Some(identity) = identity {
                builder.identity(native_tls::Identity::from_pkcs12(identity.der, identity.password).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?);
            }
            s.into_native_tls(builder.build().map_err(|e| io::Error::new(io::ErrorKind::Other, e))?, domain)
        }
    } else {
        // No TLS backend enabled: the stream stays plain. `domain` and the identity are unused.
        fn into_tls_impl(s: TcpStream, _domain: &str, _: Option<Identity<'_, '_>>) -> Result<TcpStream, HandshakeError> {
            Ok(TcpStream::Plain(s.into_plain()?))
        }
    }
}
impl From<MioTcpStream> for TcpStream {
fn from(s: MioTcpStream) -> Self {
TcpStream::Plain(s)
}
}
#[cfg(feature = "native-tls")]
impl From<NativeTlsStream> for TcpStream {
fn from(s: NativeTlsStream) -> Self {
TcpStream::NativeTls(s)
}
}
#[cfg(feature = "openssl")]
impl From<OpenSslStream> for TcpStream {
fn from(s: OpenSslStream) -> Self {
TcpStream::OpenSsl(s)
}
}
#[cfg(feature = "rustls-connector")]
impl From<RustlsStream> for TcpStream {
fn from(s: RustlsStream) -> Self {
TcpStream::Rustls(s)
}
}
impl Deref for TcpStream {
type Target = MioTcpStream;
fn deref(&self) -> &Self::Target {
match self {
TcpStream::Plain(plain) => plain,
#[cfg(feature = "native-tls")]
TcpStream::NativeTls(tls) => tls.get_ref(),
#[cfg(feature = "openssl")]
TcpStream::OpenSsl(tls) => tls.get_ref(),
#[cfg(feature = "rustls-connector")]
TcpStream::Rustls(tls) => tls.get_ref(),
}
}
}
impl DerefMut for TcpStream {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
TcpStream::Plain(plain) => plain,
#[cfg(feature = "native-tls")]
TcpStream::NativeTls(tls) => tls.get_mut(),
#[cfg(feature = "openssl")]
TcpStream::OpenSsl(tls) => tls.get_mut(),
#[cfg(feature = "rustls-connector")]
TcpStream::Rustls(tls) => tls.get_mut(),
}
}
}
impl Read for TcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
TcpStream::Plain(ref mut plain) => plain.read(buf),
#[cfg(feature = "native-tls")]
TcpStream::NativeTls(ref mut tls) => tls.read(buf),
#[cfg(feature = "openssl")]
TcpStream::OpenSsl(ref mut tls) => tls.read(buf),
#[cfg(feature = "rustls-connector")]
TcpStream::Rustls(ref mut tls) => tls.read(buf),
}
}
}
impl Write for TcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
TcpStream::Plain(ref mut plain) => plain.write(buf),
#[cfg(feature = "native-tls")]
TcpStream::NativeTls(ref mut tls) => tls.write(buf),
#[cfg(feature = "openssl")]
TcpStream::OpenSsl(ref mut tls) => tls.write(buf),
#[cfg(feature = "rustls-connector")]
TcpStream::Rustls(ref mut tls) => tls.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
TcpStream::Plain(ref mut plain) => plain.flush(),
#[cfg(feature = "native-tls")]
TcpStream::NativeTls(ref mut tls) => tls.flush(),
#[cfg(feature = "openssl")]
TcpStream::OpenSsl(ref mut tls) => tls.flush(),
#[cfg(feature = "rustls-connector")]
TcpStream::Rustls(ref mut tls) => tls.flush(),
}
}
}
impl Source for TcpStream {
fn register(
&mut self,
registry: &Registry,
token: Token,
interests: Interest,
) -> io::Result<()> {
<MioTcpStream as Source>::register(self, registry, token, interests)
}
fn reregister(
&mut self,
registry: &Registry,
token: Token,
interests: Interest,
) -> io::Result<()> {
<MioTcpStream as Source>::reregister(self, registry, token, interests)
}
fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
<MioTcpStream as Source>::deregister(self, registry)
}
}
/// Formats exactly like the underlying `mio` TCP stream; TLS state is not shown.
impl fmt::Debug for TcpStream {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let socket: &MioTcpStream = self;
        fmt::Debug::fmt(socket, formatter)
    }
}
/// A TLS handshake that could not finish without blocking, tagged with the backend that
/// started it.
///
/// Resume it with `handshake`.
// The backend variants hold their in-progress streams by value and may be much larger
// than `Plain`. The size lint is silenced rather than boxing the public payloads.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum MidHandshakeTlsStream {
    /// A plain stream; `handshake` returns it unchanged as `TcpStream::Plain`.
    Plain(MioTcpStream),
    /// An in-progress native-tls handshake.
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsMidHandshakeTlsStream),
    /// An in-progress OpenSSL handshake.
    // NOTE(review): spelled `Openssl`, unlike `TcpStream::OpenSsl`; renaming would break callers.
    #[cfg(feature = "openssl")]
    Openssl(OpenSslMidHandshakeTlsStream),
    /// An in-progress rustls handshake.
    #[cfg(feature = "rustls-connector")]
    Rustls(RustlsMidHandshakeTlsStream),
}
impl MidHandshakeTlsStream {
    /// Returns a shared reference to the underlying `mio` TCP stream.
    pub fn get_ref(&self) -> &MioTcpStream {
        match self {
            MidHandshakeTlsStream::Plain(mid) => mid,
            #[cfg(feature = "native-tls")]
            MidHandshakeTlsStream::NativeTls(mid) => mid.get_ref(),
            #[cfg(feature = "openssl")]
            MidHandshakeTlsStream::Openssl(mid) => mid.get_ref(),
            #[cfg(feature = "rustls-connector")]
            MidHandshakeTlsStream::Rustls(mid) => mid.get_ref(),
        }
    }

    /// Returns a mutable reference to the underlying `mio` TCP stream.
    ///
    /// Reading or writing through this reference bypasses the TLS layer of the
    /// in-progress handshake.
    // Returns `&mut`, matching the name, the `&mut self` receiver and the backend
    // `get_mut` methods it delegates to. Existing call sites that expect `&MioTcpStream`
    // still compile because `&mut T` coerces to `&T`.
    pub fn get_mut(&mut self) -> &mut MioTcpStream {
        match self {
            MidHandshakeTlsStream::Plain(mid) => mid,
            #[cfg(feature = "native-tls")]
            MidHandshakeTlsStream::NativeTls(mid) => mid.get_mut(),
            #[cfg(feature = "openssl")]
            MidHandshakeTlsStream::Openssl(mid) => mid.get_mut(),
            #[cfg(feature = "rustls-connector")]
            MidHandshakeTlsStream::Rustls(mid) => mid.get_mut(),
        }
    }

    /// Resumes the TLS handshake.
    ///
    /// A `Plain` value completes immediately as `TcpStream::Plain`. TLS variants delegate
    /// to their backend's `handshake`.
    ///
    /// # Errors
    ///
    /// - `HandshakeError::WouldBlock`, carrying a new mid-handshake stream, when the
    ///   backend reports it would block again.
    /// - `HandshakeError::Failure` when the handshake fails.
    pub fn handshake(self) -> Result<TcpStream, HandshakeError> {
        Ok(match self {
            MidHandshakeTlsStream::Plain(mid) => TcpStream::Plain(mid),
            #[cfg(feature = "native-tls")]
            MidHandshakeTlsStream::NativeTls(mid) => mid.handshake()?.into(),
            #[cfg(feature = "openssl")]
            MidHandshakeTlsStream::Openssl(mid) => mid.handshake()?.into(),
            #[cfg(feature = "rustls-connector")]
            MidHandshakeTlsStream::Rustls(mid) => mid.handshake()?.into(),
        })
    }
}
/// Wraps an in-progress native-tls handshake.
#[cfg(feature = "native-tls")]
impl From<NativeTlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
    fn from(pending: NativeTlsMidHandshakeTlsStream) -> Self {
        Self::NativeTls(pending)
    }
}

/// Wraps an in-progress OpenSSL handshake.
#[cfg(feature = "openssl")]
impl From<OpenSslMidHandshakeTlsStream> for MidHandshakeTlsStream {
    fn from(pending: OpenSslMidHandshakeTlsStream) -> Self {
        Self::Openssl(pending)
    }
}

/// Wraps an in-progress rustls handshake.
#[cfg(feature = "rustls-connector")]
impl From<RustlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
    fn from(pending: RustlsMidHandshakeTlsStream) -> Self {
        Self::Rustls(pending)
    }
}
/// Formats as the fixed text `MidHandshakeTlsStream`; the inner stream is not shown.
impl fmt::Display for MidHandshakeTlsStream {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("MidHandshakeTlsStream")
    }
}
/// Error returned when upgrading a `TcpStream` to TLS or when resuming a handshake.
// `WouldBlock` carries the whole in-progress stream by value, which can be large.
// The size lint is silenced rather than boxing the public payload.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum HandshakeError {
    /// The handshake could not complete without blocking; retry it with
    /// `MidHandshakeTlsStream::handshake`.
    WouldBlock(MidHandshakeTlsStream),
    /// The handshake, or the connector setup preceding it, failed.
    Failure(io::Error),
}
impl fmt::Display for HandshakeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
HandshakeError::WouldBlock(_) => write!(f, "WouldBlock hit during handshake"),
HandshakeError::Failure(err) => write!(f, "IO error: {}", err),
}
}
}
/// Exposes the underlying I/O error of a `Failure` as the error source; `WouldBlock`
/// has no source.
impl Error for HandshakeError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Every variant is listed (no `_` arm), so adding a variant forces a decision here.
        match self {
            HandshakeError::WouldBlock(_) => None,
            HandshakeError::Failure(err) => Some(err),
        }
    }
}
/// Converts a native-tls handshake error, wrapping failures as `io::ErrorKind::Other`.
#[cfg(feature = "native-tls")]
impl From<NativeTlsHandshakeError> for HandshakeError {
    fn from(error: NativeTlsHandshakeError) -> Self {
        use native_tls::HandshakeError as Backend;
        match error {
            Backend::Failure(cause) => Self::Failure(io::Error::new(io::ErrorKind::Other, cause)),
            Backend::WouldBlock(pending) => Self::WouldBlock(MidHandshakeTlsStream::from(pending)),
        }
    }
}
/// Converts an OpenSSL handshake error.
///
/// Handshake failures are wrapped as `io::ErrorKind::Other`. Setup failures go through
/// the `OpenSslErrorStack` conversion.
#[cfg(feature = "openssl")]
impl From<OpenSslHandshakeError> for HandshakeError {
    fn from(error: OpenSslHandshakeError) -> Self {
        use openssl::ssl::HandshakeError as Backend;
        match error {
            Backend::WouldBlock(pending) => Self::WouldBlock(MidHandshakeTlsStream::from(pending)),
            Backend::Failure(stalled) => {
                let cause = stalled.into_error();
                Self::Failure(io::Error::new(io::ErrorKind::Other, cause))
            }
            Backend::SetupFailure(stack) => Self::from(stack),
        }
    }
}
/// Converts an OpenSSL error stack into a `Failure` via its `io::Error` conversion.
#[cfg(feature = "openssl")]
impl From<OpenSslErrorStack> for HandshakeError {
    fn from(stack: OpenSslErrorStack) -> Self {
        HandshakeError::Failure(io::Error::from(stack))
    }
}
/// Converts a rustls handshake error; its failures are already `io::Error`s.
#[cfg(feature = "rustls-connector")]
impl From<RustlsHandshakeError> for HandshakeError {
    fn from(error: RustlsHandshakeError) -> Self {
        use rustls_connector::HandshakeError as Backend;
        match error {
            Backend::Failure(io_error) => Self::Failure(io_error),
            Backend::WouldBlock(pending) => Self::WouldBlock(pending.into()),
        }
    }
}
impl From<io::Error> for HandshakeError {
fn from(err: io::Error) -> Self {
HandshakeError::Failure(err)
}
}
// Unix-specific code lives in the `unix` submodule, which is compiled only for `cfg(unix)`.
// NOTE(review): presumably raw file-descriptor trait impls for these types — confirm in `unix.rs`.
#[cfg(unix)]
mod unix;