// zero_mysql/sync/stream.rs
use std::io::{BufReader, Read, Write};
use std::mem::MaybeUninit;
use std::net::{IpAddr, TcpStream};
#[cfg(unix)]
use std::os::unix::net::UnixStream;

use crate::nightly::read_uninit_exact;

#[cfg(feature = "sync-tls")]
use native_tls::TlsStream;

/// Byte-stream transport used by the synchronous client.
///
/// Each variant wraps its underlying stream in a [`BufReader`], so reads are
/// buffered while writes go to the inner stream.
pub enum Stream {
    /// Plain TCP socket.
    Tcp(BufReader<TcpStream>),
    /// TCP socket wrapped in a `native_tls` session (only with `sync-tls`).
    #[cfg(feature = "sync-tls")]
    Tls(BufReader<TlsStream<TcpStream>>),
    /// Unix-domain socket (only on Unix targets).
    #[cfg(unix)]
    Unix(BufReader<UnixStream>),
}

20impl Stream {
21 pub fn tcp(stream: TcpStream) -> Self {
22 Self::Tcp(BufReader::new(stream))
23 }
24
25 #[cfg(unix)]
26 pub fn unix(stream: UnixStream) -> Self {
27 Self::Unix(BufReader::new(stream))
28 }
29
30 #[cfg(feature = "sync-tls")]
31 pub fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
32 let tcp = match self {
33 Self::Tcp(buf_reader) => buf_reader.into_inner(),
34 #[cfg(feature = "sync-tls")]
35 Self::Tls(_) => {
36 return Err(std::io::Error::new(
37 std::io::ErrorKind::InvalidInput,
38 "Already using TLS",
39 ));
40 }
41 #[cfg(unix)]
42 Self::Unix(_) => {
43 return Err(std::io::Error::new(
44 std::io::ErrorKind::InvalidInput,
45 "TLS not supported for Unix sockets",
46 ));
47 }
48 };
49
50 let connector = native_tls::TlsConnector::new().map_err(std::io::Error::other)?;
51 let tls_stream = connector
52 .connect(host, tcp)
53 .map_err(std::io::Error::other)?;
54
55 Ok(Self::Tls(BufReader::new(tls_stream)))
56 }
57
58 pub fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
59 match self {
60 Self::Tcp(r) => r.read_exact(buf),
61 #[cfg(feature = "sync-tls")]
62 Self::Tls(r) => r.read_exact(buf),
63 #[cfg(unix)]
64 Self::Unix(r) => r.read_exact(buf),
65 }
66 }
67
68 pub fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
69 match self {
70 Self::Tcp(r) => read_uninit_exact(r, buf),
71 #[cfg(feature = "sync-tls")]
72 Self::Tls(r) => read_uninit_exact(r, buf),
73 #[cfg(unix)]
74 Self::Unix(r) => read_uninit_exact(r, buf),
75 }
76 }
77
78 pub fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
79 match self {
80 Self::Tcp(r) => r.get_mut().write_all(buf),
81 #[cfg(feature = "sync-tls")]
82 Self::Tls(r) => r.get_mut().write_all(buf),
83 #[cfg(unix)]
84 Self::Unix(r) => r.get_mut().write_all(buf),
85 }
86 }
87
88 pub fn flush(&mut self) -> std::io::Result<()> {
89 match self {
90 Self::Tcp(r) => r.get_mut().flush(),
91 #[cfg(feature = "sync-tls")]
92 Self::Tls(r) => r.get_mut().flush(),
93 #[cfg(unix)]
94 Self::Unix(r) => r.get_mut().flush(),
95 }
96 }
97
98 pub fn is_tcp_loopback(&self) -> bool {
100 match self {
101 Self::Tcp(r) => r
102 .get_ref()
103 .peer_addr()
104 .map(|addr| addr.ip().is_loopback())
105 .unwrap_or(false),
106 #[cfg(feature = "sync-tls")]
107 Self::Tls(r) => r
108 .get_ref()
109 .get_ref()
110 .peer_addr()
111 .map(|addr| addr.ip().is_loopback())
112 .unwrap_or(false),
113 #[cfg(unix)]
114 Self::Unix(_) => false,
115 }
116 }
117}