zero_mysql/tokio/
stream.rs1use core::mem::MaybeUninit;
2use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
3use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7#[cfg(feature = "tokio-tls")]
8use tokio_native_tls::TlsStream;
9
10pub enum Stream {
11 Tcp(BufReader<TcpStream>),
12 #[cfg(feature = "tokio-tls")]
13 Tls(BufReader<TlsStream<TcpStream>>),
14 #[cfg(unix)]
15 Unix(BufReader<UnixStream>),
16}
17
18impl Stream {
19 pub fn tcp(stream: TcpStream) -> Self {
20 Self::Tcp(BufReader::new(stream))
21 }
22
23 #[cfg(unix)]
24 pub fn unix(stream: UnixStream) -> Self {
25 Self::Unix(BufReader::new(stream))
26 }
27
28 #[cfg(feature = "tokio-tls")]
29 pub async fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
30 let tcp = match self {
31 Self::Tcp(buf_reader) => buf_reader.into_inner(),
32 #[cfg(feature = "tokio-tls")]
33 Self::Tls(_) => {
34 return Err(std::io::Error::new(
35 std::io::ErrorKind::InvalidInput,
36 "Already using TLS",
37 ));
38 }
39 #[cfg(unix)]
40 Self::Unix(_) => {
41 return Err(std::io::Error::new(
42 std::io::ErrorKind::InvalidInput,
43 "TLS not supported for Unix sockets",
44 ));
45 }
46 };
47
48 let connector = native_tls::TlsConnector::new().map_err(std::io::Error::other)?;
49 let connector = tokio_native_tls::TlsConnector::from(connector);
50 let tls_stream = connector
51 .connect(host, tcp)
52 .await
53 .map_err(std::io::Error::other)?;
54
55 Ok(Self::Tls(BufReader::new(tls_stream)))
56 }
57
58 pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
59 match self {
60 Self::Tcp(reader) => reader.read_exact(buf).await.map(|_| ()),
61 #[cfg(feature = "tokio-tls")]
62 Self::Tls(reader) => reader.read_exact(buf).await.map(|_| ()),
63 #[cfg(unix)]
64 Self::Unix(reader) => reader.read_exact(buf).await.map(|_| ()),
65 }
66 }
67
68 pub async fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
69 match self {
70 Self::Tcp(reader) => read_buf_exact_impl(reader, buf).await,
71 #[cfg(feature = "tokio-tls")]
72 Self::Tls(reader) => read_buf_exact_impl(reader, buf).await,
73 #[cfg(unix)]
74 Self::Unix(reader) => read_buf_exact_impl(reader, buf).await,
75 }
76 }
77
78 pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
79 match self {
80 Self::Tcp(reader) => reader.get_mut().write_all(buf).await,
81 #[cfg(feature = "tokio-tls")]
82 Self::Tls(reader) => reader.get_mut().write_all(buf).await,
83 #[cfg(unix)]
84 Self::Unix(reader) => reader.get_mut().write_all(buf).await,
85 }
86 }
87
88 pub async fn flush(&mut self) -> std::io::Result<()> {
89 match self {
90 Self::Tcp(reader) => reader.get_mut().flush().await,
91 #[cfg(feature = "tokio-tls")]
92 Self::Tls(reader) => reader.get_mut().flush().await,
93 #[cfg(unix)]
94 Self::Unix(reader) => reader.get_mut().flush().await,
95 }
96 }
97
98 pub fn is_tcp_loopback(&self) -> bool {
100 match self {
101 Self::Tcp(r) => r
102 .get_ref()
103 .peer_addr()
104 .map(|addr| addr.ip().is_loopback())
105 .unwrap_or(false),
106 #[cfg(feature = "tokio-tls")]
107 Self::Tls(r) => r
108 .get_ref()
109 .get_ref()
110 .get_ref()
111 .get_ref()
112 .peer_addr()
113 .map(|addr| addr.ip().is_loopback())
114 .unwrap_or(false),
115 #[cfg(unix)]
116 Self::Unix(_) => false,
117 }
118 }
119}
120
121async fn read_buf_exact_impl<R: AsyncReadExt + Unpin>(
122 reader: &mut R,
123 mut buf: &mut [MaybeUninit<u8>],
124) -> std::io::Result<()> {
125 while !buf.is_empty() {
126 let n = reader.read_buf(&mut buf).await?;
127 if n == 0 {
128 return Err(std::io::Error::new(
129 std::io::ErrorKind::UnexpectedEof,
130 "failed to fill whole buffer",
131 ));
132 }
133 }
134 Ok(())
135}