// zero_mysql/sync/stream.rs

1use std::io::{BufReader, Read, Write};
2use std::mem::MaybeUninit;
3use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::net::UnixStream;
6
7use crate::nightly::read_uninit_exact;
8
9#[cfg(feature = "sync-tls")]
10use native_tls::TlsStream;
11
/// Blocking transport used by the synchronous client.
///
/// Every variant wraps its socket in a `BufReader`, so reads are buffered. Writes are issued
/// directly on the inner socket through `BufReader::get_mut`.
#[derive(Debug)]
pub enum Stream {
    /// Plain TCP connection.
    Tcp(BufReader<TcpStream>),
    /// TLS session layered over a TCP connection.
    #[cfg(feature = "sync-tls")]
    Tls(BufReader<TlsStream<TcpStream>>),
    /// Unix domain socket connection.
    #[cfg(unix)]
    Unix(BufReader<UnixStream>),
}
19
20impl Stream {
21    pub fn tcp(stream: TcpStream) -> Self {
22        Self::Tcp(BufReader::new(stream))
23    }
24
25    #[cfg(unix)]
26    pub fn unix(stream: UnixStream) -> Self {
27        Self::Unix(BufReader::new(stream))
28    }
29
30    #[cfg(feature = "sync-tls")]
31    pub fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
32        let tcp = match self {
33            Self::Tcp(buf_reader) => buf_reader.into_inner(),
34            #[cfg(feature = "sync-tls")]
35            Self::Tls(_) => {
36                return Err(std::io::Error::new(
37                    std::io::ErrorKind::InvalidInput,
38                    "Already using TLS",
39                ));
40            }
41            #[cfg(unix)]
42            Self::Unix(_) => {
43                return Err(std::io::Error::new(
44                    std::io::ErrorKind::InvalidInput,
45                    "TLS not supported for Unix sockets",
46                ));
47            }
48        };
49
50        let connector = native_tls::TlsConnector::new()
51            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
52        let tls_stream = connector
53            .connect(host, tcp)
54            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
55
56        Ok(Self::Tls(BufReader::new(tls_stream)))
57    }
58
59    pub fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
60        match self {
61            Self::Tcp(r) => r.read_exact(buf),
62            #[cfg(feature = "sync-tls")]
63            Self::Tls(r) => r.read_exact(buf),
64            #[cfg(unix)]
65            Self::Unix(r) => r.read_exact(buf),
66        }
67    }
68
69    pub fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
70        match self {
71            Self::Tcp(r) => read_uninit_exact(r, buf),
72            #[cfg(feature = "sync-tls")]
73            Self::Tls(r) => read_uninit_exact(r, buf),
74            #[cfg(unix)]
75            Self::Unix(r) => read_uninit_exact(r, buf),
76        }
77    }
78
79    pub fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
80        match self {
81            Self::Tcp(r) => r.get_mut().write_all(buf),
82            #[cfg(feature = "sync-tls")]
83            Self::Tls(r) => r.get_mut().write_all(buf),
84            #[cfg(unix)]
85            Self::Unix(r) => r.get_mut().write_all(buf),
86        }
87    }
88
89    pub fn flush(&mut self) -> std::io::Result<()> {
90        match self {
91            Self::Tcp(r) => r.get_mut().flush(),
92            #[cfg(feature = "sync-tls")]
93            Self::Tls(r) => r.get_mut().flush(),
94            #[cfg(unix)]
95            Self::Unix(r) => r.get_mut().flush(),
96        }
97    }
98
99    /// Returns true if this is a TCP connection to a loopback address
100    pub fn is_tcp_loopback(&self) -> bool {
101        match self {
102            Self::Tcp(r) => r
103                .get_ref()
104                .peer_addr()
105                .map(|addr| addr.ip().is_loopback())
106                .unwrap_or(false),
107            #[cfg(feature = "sync-tls")]
108            Self::Tls(r) => r
109                .get_ref()
110                .get_ref()
111                .peer_addr()
112                .map(|addr| addr.ip().is_loopback())
113                .unwrap_or(false),
114            #[cfg(unix)]
115            Self::Unix(_) => false,
116        }
117    }
118}