sqlx_core/net/socket.rs

1#![allow(dead_code)]
2
3use std::io;
4use std::net::SocketAddr;
5use std::path::Path;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use sqlx_rt::{AsyncRead, AsyncWrite, TcpStream};
10
11#[derive(Debug)]
12pub enum Socket {
13    Tcp(TcpStream),
14
15    #[cfg(unix)]
16    Unix(sqlx_rt::UnixStream),
17}
18
19impl Socket {
20    pub async fn connect_tcp(host: &str, port: u16) -> io::Result<Self> {
21        // Trim square brackets from host if it's an IPv6 address as the `url` crate doesn't do that.
22        TcpStream::connect((host.trim_matches(|c| c == '[' || c == ']'), port))
23            .await
24            .map(Socket::Tcp)
25    }
26
27    #[cfg(unix)]
28    pub async fn connect_uds(path: impl AsRef<Path>) -> io::Result<Self> {
29        sqlx_rt::UnixStream::connect(path.as_ref())
30            .await
31            .map(Socket::Unix)
32    }
33
34    pub fn local_addr(&self) -> Option<SocketAddr> {
35        match self {
36            Self::Tcp(tcp) => tcp.local_addr().ok(),
37            #[cfg(unix)]
38            Self::Unix(_) => None,
39        }
40    }
41
42    #[cfg(not(unix))]
43    pub async fn connect_uds(_: impl AsRef<Path>) -> io::Result<Self> {
44        Err(io::Error::new(
45            io::ErrorKind::Other,
46            "Unix domain sockets are not supported outside Unix platforms.",
47        ))
48    }
49
50    pub async fn shutdown(&mut self) -> io::Result<()> {
51        #[cfg(feature = "_rt-async-std")]
52        {
53            use std::net::Shutdown;
54
55            match self {
56                Socket::Tcp(s) => s.shutdown(Shutdown::Both),
57
58                #[cfg(unix)]
59                Socket::Unix(s) => s.shutdown(Shutdown::Both),
60            }
61        }
62
63        #[cfg(feature = "_rt-tokio")]
64        {
65            use sqlx_rt::AsyncWriteExt;
66
67            match self {
68                Socket::Tcp(s) => s.shutdown().await,
69
70                #[cfg(unix)]
71                Socket::Unix(s) => s.shutdown().await,
72            }
73        }
74    }
75}
76
77impl AsyncRead for Socket {
78    fn poll_read(
79        mut self: Pin<&mut Self>,
80        cx: &mut Context<'_>,
81        buf: &mut super::PollReadBuf<'_>,
82    ) -> Poll<io::Result<super::PollReadOut>> {
83        match &mut *self {
84            Socket::Tcp(s) => Pin::new(s).poll_read(cx, buf),
85
86            #[cfg(unix)]
87            Socket::Unix(s) => Pin::new(s).poll_read(cx, buf),
88        }
89    }
90}
91
92impl AsyncWrite for Socket {
93    fn poll_write(
94        mut self: Pin<&mut Self>,
95        cx: &mut Context<'_>,
96        buf: &[u8],
97    ) -> Poll<io::Result<usize>> {
98        match &mut *self {
99            Socket::Tcp(s) => Pin::new(s).poll_write(cx, buf),
100
101            #[cfg(unix)]
102            Socket::Unix(s) => Pin::new(s).poll_write(cx, buf),
103        }
104    }
105
106    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
107        match &mut *self {
108            Socket::Tcp(s) => Pin::new(s).poll_flush(cx),
109
110            #[cfg(unix)]
111            Socket::Unix(s) => Pin::new(s).poll_flush(cx),
112        }
113    }
114
115    #[cfg(feature = "_rt-tokio")]
116    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
117        match &mut *self {
118            Socket::Tcp(s) => Pin::new(s).poll_shutdown(cx),
119
120            #[cfg(unix)]
121            Socket::Unix(s) => Pin::new(s).poll_shutdown(cx),
122        }
123    }
124
125    #[cfg(feature = "_rt-async-std")]
126    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
127        match &mut *self {
128            Socket::Tcp(s) => Pin::new(s).poll_close(cx),
129
130            #[cfg(unix)]
131            Socket::Unix(s) => Pin::new(s).poll_close(cx),
132        }
133    }
134}