// zero_mysql/tokio/stream.rs
use core::mem::MaybeUninit;

use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpStream, UnixStream};

#[cfg(feature = "tokio-tls")]
use tokio_native_tls::TlsStream;

8pub enum Stream {
9 Tcp(BufReader<TcpStream>),
10 #[cfg(feature = "tokio-tls")]
11 Tls(BufReader<TlsStream<TcpStream>>),
12 Unix(BufReader<UnixStream>),
13}
14
15impl Stream {
16 pub fn tcp(stream: TcpStream) -> Self {
17 Self::Tcp(BufReader::new(stream))
18 }
19
20 pub fn unix(stream: UnixStream) -> Self {
21 Self::Unix(BufReader::new(stream))
22 }
23
24 #[cfg(feature = "tokio-tls")]
25 pub async fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
26 let tcp = match self {
27 Self::Tcp(buf_reader) => buf_reader.into_inner(),
28 #[cfg(feature = "tokio-tls")]
29 Self::Tls(_) => {
30 return Err(std::io::Error::new(
31 std::io::ErrorKind::InvalidInput,
32 "Already using TLS",
33 ));
34 }
35 Self::Unix(_) => {
36 return Err(std::io::Error::new(
37 std::io::ErrorKind::InvalidInput,
38 "TLS not supported for Unix sockets",
39 ));
40 }
41 };
42
43 let connector = native_tls::TlsConnector::new()
44 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
45 let connector = tokio_native_tls::TlsConnector::from(connector);
46 let tls_stream = connector
47 .connect(host, tcp)
48 .await
49 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
50
51 Ok(Self::Tls(BufReader::new(tls_stream)))
52 }
53
54 pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
55 match self {
56 Self::Tcp(reader) => reader.read_exact(buf).await.map(|_| ()),
57 #[cfg(feature = "tokio-tls")]
58 Self::Tls(reader) => reader.read_exact(buf).await.map(|_| ()),
59 Self::Unix(reader) => reader.read_exact(buf).await.map(|_| ()),
60 }
61 }
62
63 pub async fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
64 match self {
65 Self::Tcp(reader) => read_buf_exact_impl(reader, buf).await,
66 #[cfg(feature = "tokio-tls")]
67 Self::Tls(reader) => read_buf_exact_impl(reader, buf).await,
68 Self::Unix(reader) => read_buf_exact_impl(reader, buf).await,
69 }
70 }
71
72 pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
73 match self {
74 Self::Tcp(reader) => reader.get_mut().write_all(buf).await,
75 #[cfg(feature = "tokio-tls")]
76 Self::Tls(reader) => reader.get_mut().write_all(buf).await,
77 Self::Unix(reader) => reader.get_mut().write_all(buf).await,
78 }
79 }
80
81 pub async fn flush(&mut self) -> std::io::Result<()> {
82 match self {
83 Self::Tcp(reader) => reader.get_mut().flush().await,
84 #[cfg(feature = "tokio-tls")]
85 Self::Tls(reader) => reader.get_mut().flush().await,
86 Self::Unix(reader) => reader.get_mut().flush().await,
87 }
88 }
89
90 pub fn is_tcp_loopback(&self) -> bool {
92 match self {
93 Self::Tcp(r) => r
94 .get_ref()
95 .peer_addr()
96 .map(|addr| addr.ip().is_loopback())
97 .unwrap_or(false),
98 #[cfg(feature = "tokio-tls")]
99 Self::Tls(r) => r
100 .get_ref()
101 .get_ref()
102 .get_ref()
103 .get_ref()
104 .peer_addr()
105 .map(|addr| addr.ip().is_loopback())
106 .unwrap_or(false),
107 Self::Unix(_) => false,
108 }
109 }
110}
111
112async fn read_buf_exact_impl<R: AsyncReadExt + Unpin>(
113 reader: &mut R,
114 mut buf: &mut [MaybeUninit<u8>],
115) -> std::io::Result<()> {
116 while !buf.is_empty() {
117 let n = reader.read_buf(&mut buf).await?;
118 if n == 0 {
119 return Err(std::io::Error::new(
120 std::io::ErrorKind::UnexpectedEof,
121 "failed to fill whole buffer",
122 ));
123 }
124 }
125 Ok(())
126}