Skip to main content

zero_mysql/tokio/
stream.rs

1use core::mem::MaybeUninit;
2use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
3use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7#[cfg(feature = "tokio-tls")]
8use tokio_native_tls::TlsStream;
9
10pub enum Stream {
11    Tcp(BufReader<TcpStream>),
12    #[cfg(feature = "tokio-tls")]
13    Tls(BufReader<TlsStream<TcpStream>>),
14    #[cfg(unix)]
15    Unix(BufReader<UnixStream>),
16}
17
18impl Stream {
19    pub fn tcp(stream: TcpStream) -> Self {
20        Self::Tcp(BufReader::new(stream))
21    }
22
23    #[cfg(unix)]
24    pub fn unix(stream: UnixStream) -> Self {
25        Self::Unix(BufReader::new(stream))
26    }
27
28    #[cfg(feature = "tokio-tls")]
29    pub async fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
30        let tcp = match self {
31            Self::Tcp(buf_reader) => buf_reader.into_inner(),
32            #[cfg(feature = "tokio-tls")]
33            Self::Tls(_) => {
34                return Err(std::io::Error::new(
35                    std::io::ErrorKind::InvalidInput,
36                    "Already using TLS",
37                ));
38            }
39            #[cfg(unix)]
40            Self::Unix(_) => {
41                return Err(std::io::Error::new(
42                    std::io::ErrorKind::InvalidInput,
43                    "TLS not supported for Unix sockets",
44                ));
45            }
46        };
47
48        let connector = native_tls::TlsConnector::new().map_err(std::io::Error::other)?;
49        let connector = tokio_native_tls::TlsConnector::from(connector);
50        let tls_stream = connector
51            .connect(host, tcp)
52            .await
53            .map_err(std::io::Error::other)?;
54
55        Ok(Self::Tls(BufReader::new(tls_stream)))
56    }
57
58    pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
59        match self {
60            Self::Tcp(reader) => reader.read_exact(buf).await.map(|_| ()),
61            #[cfg(feature = "tokio-tls")]
62            Self::Tls(reader) => reader.read_exact(buf).await.map(|_| ()),
63            #[cfg(unix)]
64            Self::Unix(reader) => reader.read_exact(buf).await.map(|_| ()),
65        }
66    }
67
68    pub async fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
69        match self {
70            Self::Tcp(reader) => read_buf_exact_impl(reader, buf).await,
71            #[cfg(feature = "tokio-tls")]
72            Self::Tls(reader) => read_buf_exact_impl(reader, buf).await,
73            #[cfg(unix)]
74            Self::Unix(reader) => read_buf_exact_impl(reader, buf).await,
75        }
76    }
77
78    pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
79        match self {
80            Self::Tcp(reader) => reader.get_mut().write_all(buf).await,
81            #[cfg(feature = "tokio-tls")]
82            Self::Tls(reader) => reader.get_mut().write_all(buf).await,
83            #[cfg(unix)]
84            Self::Unix(reader) => reader.get_mut().write_all(buf).await,
85        }
86    }
87
88    pub async fn flush(&mut self) -> std::io::Result<()> {
89        match self {
90            Self::Tcp(reader) => reader.get_mut().flush().await,
91            #[cfg(feature = "tokio-tls")]
92            Self::Tls(reader) => reader.get_mut().flush().await,
93            #[cfg(unix)]
94            Self::Unix(reader) => reader.get_mut().flush().await,
95        }
96    }
97
98    /// Returns true if this is a TCP connection to a loopback address
99    pub fn is_tcp_loopback(&self) -> bool {
100        match self {
101            Self::Tcp(r) => r
102                .get_ref()
103                .peer_addr()
104                .map(|addr| addr.ip().is_loopback())
105                .unwrap_or(false),
106            #[cfg(feature = "tokio-tls")]
107            Self::Tls(r) => r
108                .get_ref()
109                .get_ref()
110                .get_ref()
111                .get_ref()
112                .peer_addr()
113                .map(|addr| addr.ip().is_loopback())
114                .unwrap_or(false),
115            #[cfg(unix)]
116            Self::Unix(_) => false,
117        }
118    }
119}
120
121async fn read_buf_exact_impl<R: AsyncReadExt + Unpin>(
122    reader: &mut R,
123    mut buf: &mut [MaybeUninit<u8>],
124) -> std::io::Result<()> {
125    while !buf.is_empty() {
126        let n = reader.read_buf(&mut buf).await?;
127        if n == 0 {
128            return Err(std::io::Error::new(
129                std::io::ErrorKind::UnexpectedEof,
130                "failed to fill whole buffer",
131            ));
132        }
133    }
134    Ok(())
135}