// zero_mysql/sync/stream.rs

1use core::io::BorrowedCursor;
2use std::io::{BufReader, Read, Write};
3use std::net::TcpStream;
4use std::os::unix::net::UnixStream;
5
6#[cfg(feature = "sync-tls")]
7use native_tls::TlsStream;
8
/// A blocking byte stream over TCP, TLS-over-TCP (with the `sync-tls` feature),
/// or a Unix-domain socket.
///
/// Every variant wraps its transport in a [`BufReader`], so reads are buffered;
/// the inner stream stays reachable for direct (unbuffered) writes.
#[derive(Debug)]
pub enum Stream {
    /// Plain TCP connection.
    Tcp(BufReader<TcpStream>),
    /// TLS session layered over a TCP connection; only available with the
    /// `sync-tls` feature.
    #[cfg(feature = "sync-tls")]
    Tls(BufReader<TlsStream<TcpStream>>),
    /// Unix-domain socket connection.
    Unix(BufReader<UnixStream>),
}
15
16impl Stream {
17    pub fn tcp(stream: TcpStream) -> Self {
18        Self::Tcp(BufReader::new(stream))
19    }
20
21    pub fn unix(stream: UnixStream) -> Self {
22        Self::Unix(BufReader::new(stream))
23    }
24
25    #[cfg(feature = "sync-tls")]
26    pub fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
27        let tcp = match self {
28            Self::Tcp(buf_reader) => buf_reader.into_inner(),
29            #[cfg(feature = "sync-tls")]
30            Self::Tls(_) => {
31                return Err(std::io::Error::new(
32                    std::io::ErrorKind::InvalidInput,
33                    "Already using TLS",
34                ));
35            }
36            Self::Unix(_) => {
37                return Err(std::io::Error::new(
38                    std::io::ErrorKind::InvalidInput,
39                    "TLS not supported for Unix sockets",
40                ));
41            }
42        };
43
44        let connector = native_tls::TlsConnector::new()
45            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
46        let tls_stream = connector
47            .connect(host, tcp)
48            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
49
50        Ok(Self::Tls(BufReader::new(tls_stream)))
51    }
52
53    pub fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
54        match self {
55            Self::Tcp(r) => r.read_exact(buf),
56            #[cfg(feature = "sync-tls")]
57            Self::Tls(r) => r.read_exact(buf),
58            Self::Unix(r) => r.read_exact(buf),
59        }
60    }
61
62    pub fn read_buf_exact(&mut self, cursor: BorrowedCursor<'_>) -> std::io::Result<()> {
63        match self {
64            Self::Tcp(r) => r.read_buf_exact(cursor),
65            #[cfg(feature = "sync-tls")]
66            Self::Tls(r) => r.read_buf_exact(cursor),
67            Self::Unix(r) => r.read_buf_exact(cursor),
68        }
69    }
70
71    pub fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
72        match self {
73            Self::Tcp(r) => r.get_mut().write_all(buf),
74            #[cfg(feature = "sync-tls")]
75            Self::Tls(r) => r.get_mut().write_all(buf),
76            Self::Unix(r) => r.get_mut().write_all(buf),
77        }
78    }
79
80    pub fn flush(&mut self) -> std::io::Result<()> {
81        match self {
82            Self::Tcp(r) => r.get_mut().flush(),
83            #[cfg(feature = "sync-tls")]
84            Self::Tls(r) => r.get_mut().flush(),
85            Self::Unix(r) => r.get_mut().flush(),
86        }
87    }
88
89    /// Returns true if this is a TCP connection to a loopback address
90    pub fn is_tcp_loopback(&self) -> bool {
91        match self {
92            Self::Tcp(r) => r
93                .get_ref()
94                .peer_addr()
95                .map(|addr| addr.ip().is_loopback())
96                .unwrap_or(false),
97            #[cfg(feature = "sync-tls")]
98            Self::Tls(r) => r
99                .get_ref()
100                .get_ref()
101                .peer_addr()
102                .map(|addr| addr.ip().is_loopback())
103                .unwrap_or(false),
104            Self::Unix(_) => false,
105        }
106    }
107}