zero_mysql/sync/stream.rs

1use core::io::BorrowedCursor;
2use std::io::{BufReader, Read, Write};
3use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::net::UnixStream;
6
7#[cfg(feature = "sync-tls")]
8use native_tls::TlsStream;
9
/// Synchronous byte transport for a connection.
///
/// Each variant wraps its socket in a [`BufReader`], so reads are buffered.
/// `BufReader` does not buffer writes.
pub enum Stream {
    /// Plain, unencrypted TCP connection.
    Tcp(BufReader<TcpStream>),
    /// TCP connection wrapped in a `native_tls` TLS session.
    /// Only available with the `sync-tls` feature.
    #[cfg(feature = "sync-tls")]
    Tls(BufReader<TlsStream<TcpStream>>),
    /// Unix domain socket connection. Only available on Unix targets.
    #[cfg(unix)]
    Unix(BufReader<UnixStream>),
}
17
18impl Stream {
19    pub fn tcp(stream: TcpStream) -> Self {
20        Self::Tcp(BufReader::new(stream))
21    }
22
23    #[cfg(unix)]
24    pub fn unix(stream: UnixStream) -> Self {
25        Self::Unix(BufReader::new(stream))
26    }
27
28    #[cfg(feature = "sync-tls")]
29    pub fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
30        let tcp = match self {
31            Self::Tcp(buf_reader) => buf_reader.into_inner(),
32            #[cfg(feature = "sync-tls")]
33            Self::Tls(_) => {
34                return Err(std::io::Error::new(
35                    std::io::ErrorKind::InvalidInput,
36                    "Already using TLS",
37                ));
38            }
39            #[cfg(unix)]
40            Self::Unix(_) => {
41                return Err(std::io::Error::new(
42                    std::io::ErrorKind::InvalidInput,
43                    "TLS not supported for Unix sockets",
44                ));
45            }
46        };
47
48        let connector = native_tls::TlsConnector::new()
49            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
50        let tls_stream = connector
51            .connect(host, tcp)
52            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
53
54        Ok(Self::Tls(BufReader::new(tls_stream)))
55    }
56
57    pub fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
58        match self {
59            Self::Tcp(r) => r.read_exact(buf),
60            #[cfg(feature = "sync-tls")]
61            Self::Tls(r) => r.read_exact(buf),
62            #[cfg(unix)]
63            Self::Unix(r) => r.read_exact(buf),
64        }
65    }
66
67    pub fn read_buf_exact(&mut self, cursor: BorrowedCursor<'_>) -> std::io::Result<()> {
68        match self {
69            Self::Tcp(r) => r.read_buf_exact(cursor),
70            #[cfg(feature = "sync-tls")]
71            Self::Tls(r) => r.read_buf_exact(cursor),
72            #[cfg(unix)]
73            Self::Unix(r) => r.read_buf_exact(cursor),
74        }
75    }
76
77    pub fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
78        match self {
79            Self::Tcp(r) => r.get_mut().write_all(buf),
80            #[cfg(feature = "sync-tls")]
81            Self::Tls(r) => r.get_mut().write_all(buf),
82            #[cfg(unix)]
83            Self::Unix(r) => r.get_mut().write_all(buf),
84        }
85    }
86
87    pub fn flush(&mut self) -> std::io::Result<()> {
88        match self {
89            Self::Tcp(r) => r.get_mut().flush(),
90            #[cfg(feature = "sync-tls")]
91            Self::Tls(r) => r.get_mut().flush(),
92            #[cfg(unix)]
93            Self::Unix(r) => r.get_mut().flush(),
94        }
95    }
96
97    /// Returns true if this is a TCP connection to a loopback address
98    pub fn is_tcp_loopback(&self) -> bool {
99        match self {
100            Self::Tcp(r) => r
101                .get_ref()
102                .peer_addr()
103                .map(|addr| addr.ip().is_loopback())
104                .unwrap_or(false),
105            #[cfg(feature = "sync-tls")]
106            Self::Tls(r) => r
107                .get_ref()
108                .get_ref()
109                .peer_addr()
110                .map(|addr| addr.ip().is_loopback())
111                .unwrap_or(false),
112            #[cfg(unix)]
113            Self::Unix(_) => false,
114        }
115    }
116}