zero_mysql/sync/
stream.rs1use core::io::BorrowedCursor;
2use std::io::{BufReader, Read, Write};
3use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::net::UnixStream;
6
7#[cfg(feature = "sync-tls")]
8use native_tls::TlsStream;
9
/// Transport used by a synchronous connection.
///
/// Variants:
/// - plain TCP;
/// - TLS over TCP, available with the `sync-tls` feature;
/// - a Unix domain socket, available on Unix platforms.
///
/// Every variant wraps its socket in a [`BufReader`], so reads are buffered.
#[derive(Debug)]
pub enum Stream {
    /// Plain TCP socket.
    Tcp(BufReader<TcpStream>),
    /// `native_tls` session layered over a TCP socket.
    #[cfg(feature = "sync-tls")]
    Tls(BufReader<TlsStream<TcpStream>>),
    /// Unix domain socket.
    #[cfg(unix)]
    Unix(BufReader<UnixStream>),
}
17
18impl Stream {
19 pub fn tcp(stream: TcpStream) -> Self {
20 Self::Tcp(BufReader::new(stream))
21 }
22
23 #[cfg(unix)]
24 pub fn unix(stream: UnixStream) -> Self {
25 Self::Unix(BufReader::new(stream))
26 }
27
28 #[cfg(feature = "sync-tls")]
29 pub fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
30 let tcp = match self {
31 Self::Tcp(buf_reader) => buf_reader.into_inner(),
32 #[cfg(feature = "sync-tls")]
33 Self::Tls(_) => {
34 return Err(std::io::Error::new(
35 std::io::ErrorKind::InvalidInput,
36 "Already using TLS",
37 ));
38 }
39 #[cfg(unix)]
40 Self::Unix(_) => {
41 return Err(std::io::Error::new(
42 std::io::ErrorKind::InvalidInput,
43 "TLS not supported for Unix sockets",
44 ));
45 }
46 };
47
48 let connector = native_tls::TlsConnector::new()
49 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
50 let tls_stream = connector
51 .connect(host, tcp)
52 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
53
54 Ok(Self::Tls(BufReader::new(tls_stream)))
55 }
56
57 pub fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
58 match self {
59 Self::Tcp(r) => r.read_exact(buf),
60 #[cfg(feature = "sync-tls")]
61 Self::Tls(r) => r.read_exact(buf),
62 #[cfg(unix)]
63 Self::Unix(r) => r.read_exact(buf),
64 }
65 }
66
67 pub fn read_buf_exact(&mut self, cursor: BorrowedCursor<'_>) -> std::io::Result<()> {
68 match self {
69 Self::Tcp(r) => r.read_buf_exact(cursor),
70 #[cfg(feature = "sync-tls")]
71 Self::Tls(r) => r.read_buf_exact(cursor),
72 #[cfg(unix)]
73 Self::Unix(r) => r.read_buf_exact(cursor),
74 }
75 }
76
77 pub fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
78 match self {
79 Self::Tcp(r) => r.get_mut().write_all(buf),
80 #[cfg(feature = "sync-tls")]
81 Self::Tls(r) => r.get_mut().write_all(buf),
82 #[cfg(unix)]
83 Self::Unix(r) => r.get_mut().write_all(buf),
84 }
85 }
86
87 pub fn flush(&mut self) -> std::io::Result<()> {
88 match self {
89 Self::Tcp(r) => r.get_mut().flush(),
90 #[cfg(feature = "sync-tls")]
91 Self::Tls(r) => r.get_mut().flush(),
92 #[cfg(unix)]
93 Self::Unix(r) => r.get_mut().flush(),
94 }
95 }
96
97 pub fn is_tcp_loopback(&self) -> bool {
99 match self {
100 Self::Tcp(r) => r
101 .get_ref()
102 .peer_addr()
103 .map(|addr| addr.ip().is_loopback())
104 .unwrap_or(false),
105 #[cfg(feature = "sync-tls")]
106 Self::Tls(r) => r
107 .get_ref()
108 .get_ref()
109 .peer_addr()
110 .map(|addr| addr.ip().is_loopback())
111 .unwrap_or(false),
112 #[cfg(unix)]
113 Self::Unix(_) => false,
114 }
115 }
116}