Skip to main content

zero_mysql/sync/
stream.rs

1use std::io::{BufReader, Read, Write};
2use std::mem::MaybeUninit;
3use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::net::UnixStream;
6
7use crate::nightly::read_uninit_exact;
8
9#[cfg(feature = "sync-tls")]
10use native_tls::TlsStream;
11
/// A buffered, bidirectional connection stream.
///
/// Every variant wraps its socket in a [`BufReader`], so reads are buffered.
/// Writes are sent straight to the wrapped socket rather than through the
/// read buffer.
#[derive(Debug)]
pub enum Stream {
    /// Plain, unencrypted TCP connection.
    Tcp(BufReader<TcpStream>),
    /// TCP connection wrapped in a `native_tls` TLS session.
    #[cfg(feature = "sync-tls")]
    Tls(BufReader<TlsStream<TcpStream>>),
    /// Unix domain socket connection (Unix platforms only).
    #[cfg(unix)]
    Unix(BufReader<UnixStream>),
}
19
20impl Stream {
21    pub fn tcp(stream: TcpStream) -> Self {
22        Self::Tcp(BufReader::new(stream))
23    }
24
25    #[cfg(unix)]
26    pub fn unix(stream: UnixStream) -> Self {
27        Self::Unix(BufReader::new(stream))
28    }
29
30    #[cfg(feature = "sync-tls")]
31    pub fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
32        let tcp = match self {
33            Self::Tcp(buf_reader) => buf_reader.into_inner(),
34            #[cfg(feature = "sync-tls")]
35            Self::Tls(_) => {
36                return Err(std::io::Error::new(
37                    std::io::ErrorKind::InvalidInput,
38                    "Already using TLS",
39                ));
40            }
41            #[cfg(unix)]
42            Self::Unix(_) => {
43                return Err(std::io::Error::new(
44                    std::io::ErrorKind::InvalidInput,
45                    "TLS not supported for Unix sockets",
46                ));
47            }
48        };
49
50        let connector = native_tls::TlsConnector::new().map_err(std::io::Error::other)?;
51        let tls_stream = connector
52            .connect(host, tcp)
53            .map_err(std::io::Error::other)?;
54
55        Ok(Self::Tls(BufReader::new(tls_stream)))
56    }
57
58    pub fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
59        match self {
60            Self::Tcp(r) => r.read_exact(buf),
61            #[cfg(feature = "sync-tls")]
62            Self::Tls(r) => r.read_exact(buf),
63            #[cfg(unix)]
64            Self::Unix(r) => r.read_exact(buf),
65        }
66    }
67
68    pub fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
69        match self {
70            Self::Tcp(r) => read_uninit_exact(r, buf),
71            #[cfg(feature = "sync-tls")]
72            Self::Tls(r) => read_uninit_exact(r, buf),
73            #[cfg(unix)]
74            Self::Unix(r) => read_uninit_exact(r, buf),
75        }
76    }
77
78    pub fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
79        match self {
80            Self::Tcp(r) => r.get_mut().write_all(buf),
81            #[cfg(feature = "sync-tls")]
82            Self::Tls(r) => r.get_mut().write_all(buf),
83            #[cfg(unix)]
84            Self::Unix(r) => r.get_mut().write_all(buf),
85        }
86    }
87
88    pub fn flush(&mut self) -> std::io::Result<()> {
89        match self {
90            Self::Tcp(r) => r.get_mut().flush(),
91            #[cfg(feature = "sync-tls")]
92            Self::Tls(r) => r.get_mut().flush(),
93            #[cfg(unix)]
94            Self::Unix(r) => r.get_mut().flush(),
95        }
96    }
97
98    /// Returns true if this is a TCP connection to a loopback address
99    pub fn is_tcp_loopback(&self) -> bool {
100        match self {
101            Self::Tcp(r) => r
102                .get_ref()
103                .peer_addr()
104                .map(|addr| addr.ip().is_loopback())
105                .unwrap_or(false),
106            #[cfg(feature = "sync-tls")]
107            Self::Tls(r) => r
108                .get_ref()
109                .get_ref()
110                .peer_addr()
111                .map(|addr| addr.ip().is_loopback())
112                .unwrap_or(false),
113            #[cfg(unix)]
114            Self::Unix(_) => false,
115        }
116    }
117}