sqlx_core/net/socket/
mod.rs

1use std::future::Future;
2use std::io;
3use std::path::Path;
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6
7pub use buffered::{BufferedSocket, WriteBuffer};
8use bytes::BufMut;
9use cfg_if::cfg_if;
10
11use crate::io::ReadBuf;
12
13mod buffered;
14
15pub trait Socket: Send + Sync + Unpin + 'static {
16    fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize>;
17
18    fn try_write(&mut self, buf: &[u8]) -> io::Result<usize>;
19
20    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
21
22    fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
23
24    fn poll_flush(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
25        // `flush()` is a no-op for TCP/UDS
26        Poll::Ready(Ok(()))
27    }
28
29    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
30
31    fn read<'a, B: ReadBuf>(&'a mut self, buf: &'a mut B) -> Read<'a, Self, B>
32    where
33        Self: Sized,
34    {
35        Read { socket: self, buf }
36    }
37
38    fn write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, Self>
39    where
40        Self: Sized,
41    {
42        Write { socket: self, buf }
43    }
44
45    fn flush(&mut self) -> Flush<'_, Self>
46    where
47        Self: Sized,
48    {
49        Flush { socket: self }
50    }
51
52    fn shutdown(&mut self) -> Shutdown<'_, Self>
53    where
54        Self: Sized,
55    {
56        Shutdown { socket: self }
57    }
58}
59
60pub struct Read<'a, S: ?Sized, B> {
61    socket: &'a mut S,
62    buf: &'a mut B,
63}
64
65impl<S: ?Sized, B> Future for Read<'_, S, B>
66where
67    S: Socket,
68    B: ReadBuf,
69{
70    type Output = io::Result<usize>;
71
72    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
73        let this = &mut *self;
74
75        while this.buf.has_remaining_mut() {
76            match this.socket.try_read(&mut *this.buf) {
77                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
78                    ready!(this.socket.poll_read_ready(cx))?;
79                }
80                ready => return Poll::Ready(ready),
81            }
82        }
83
84        Poll::Ready(Ok(0))
85    }
86}
87
88pub struct Write<'a, S: ?Sized> {
89    socket: &'a mut S,
90    buf: &'a [u8],
91}
92
93impl<S: ?Sized> Future for Write<'_, S>
94where
95    S: Socket,
96{
97    type Output = io::Result<usize>;
98
99    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100        let this = &mut *self;
101
102        while !this.buf.is_empty() {
103            match this.socket.try_write(this.buf) {
104                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
105                    ready!(this.socket.poll_write_ready(cx))?;
106                }
107                ready => return Poll::Ready(ready),
108            }
109        }
110
111        Poll::Ready(Ok(0))
112    }
113}
114
115pub struct Flush<'a, S: ?Sized> {
116    socket: &'a mut S,
117}
118
119impl<S: Socket + ?Sized> Future for Flush<'_, S> {
120    type Output = io::Result<()>;
121
122    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123        self.socket.poll_flush(cx)
124    }
125}
126
127pub struct Shutdown<'a, S: ?Sized> {
128    socket: &'a mut S,
129}
130
131impl<S: ?Sized> Future for Shutdown<'_, S>
132where
133    S: Socket,
134{
135    type Output = io::Result<()>;
136
137    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138        self.socket.poll_shutdown(cx)
139    }
140}
141
142pub trait WithSocket {
143    type Output;
144
145    fn with_socket<S: Socket>(self, socket: S) -> impl Future<Output = Self::Output> + Send;
146}
147
148pub struct SocketIntoBox;
149
150impl WithSocket for SocketIntoBox {
151    type Output = Box<dyn Socket>;
152
153    async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
154        Box::new(socket)
155    }
156}
157
158impl<S: Socket + ?Sized> Socket for Box<S> {
159    fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize> {
160        (**self).try_read(buf)
161    }
162
163    fn try_write(&mut self, buf: &[u8]) -> io::Result<usize> {
164        (**self).try_write(buf)
165    }
166
167    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
168        (**self).poll_read_ready(cx)
169    }
170
171    fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
172        (**self).poll_write_ready(cx)
173    }
174
175    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
176        (**self).poll_flush(cx)
177    }
178
179    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180        (**self).poll_shutdown(cx)
181    }
182}
183
184pub async fn connect_tcp<Ws: WithSocket>(
185    host: &str,
186    port: u16,
187    with_socket: Ws,
188) -> crate::Result<Ws::Output> {
189    #[cfg(feature = "_rt-tokio")]
190    if crate::rt::rt_tokio::available() {
191        return Ok(with_socket
192            .with_socket(tokio::net::TcpStream::connect((host, port)).await?)
193            .await);
194    }
195
196    cfg_if! {
197        if #[cfg(feature = "_rt-async-io")] {
198            Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)
199        } else {
200            crate::rt::missing_rt((host, port, with_socket))
201        }
202    }
203}
204
/// Open a TCP socket to `host` and `port` using `async-io`.
///
/// `host` may be an IP literal, optionally wrapped in `[...]` as IPv6 hosts appear in URLs,
/// or a hostname. IP literals are dialed directly. Hostnames are resolved on a blocking task,
/// and each resulting address is tried in order until one connects.
///
/// This implements the same behavior as [`tokio::net::TcpStream::connect()`].
///
/// # Errors
/// Returns one of:
/// - the resolution error;
/// - the error from the last failed connection attempt;
/// - an [`io::ErrorKind::AddrNotAvailable`] error if the hostname resolved to no addresses.
#[cfg(feature = "_rt-async-io")]
async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> {
    use async_io::Async;
    use std::net::{IpAddr, TcpStream, ToSocketAddrs};

    // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
    let bare_host = host.trim_matches(&['[', ']'][..]);

    // An IP literal can be dialed directly, with no name resolution.
    if let Ok(ip) = bare_host.parse::<IpAddr>() {
        return Ok(Async::<TcpStream>::connect((ip, port)).await?);
    }

    // `to_socket_addrs` performs blocking name resolution, so run it off the async executor.
    // The closure must own its data, hence the owned copy of the host.
    let owned_host = bare_host.to_owned();
    let candidates = crate::rt::spawn_blocking(move || {
        let target = (owned_host.as_str(), port);
        target.to_socket_addrs()
    })
    .await?;

    let mut last_err: Option<io::Error> = None;

    // Try each resolved address in order; the first successful connection wins.
    for candidate in candidates {
        let err = match Async::<TcpStream>::connect(candidate).await {
            Ok(stream) => return Ok(stream),
            Err(err) => err,
        };
        last_err = Some(err);
    }

    // Every attempt failed, or resolution yielded nothing. Surface the most recent failure,
    // or a synthetic error when there was nothing to try.
    let err = last_err.unwrap_or_else(|| {
        io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            "Hostname did not resolve to any addresses",
        )
    });
    Err(err.into())
}
251
252/// Connect a Unix Domain Socket at the given path.
253///
254/// Returns an error if Unix Domain Sockets are not supported on this platform.
255pub async fn connect_uds<P: AsRef<Path>, Ws: WithSocket>(
256    path: P,
257    with_socket: Ws,
258) -> crate::Result<Ws::Output> {
259    #[cfg(unix)]
260    {
261        #[cfg(feature = "_rt-tokio")]
262        if crate::rt::rt_tokio::available() {
263            use tokio::net::UnixStream;
264
265            let stream = UnixStream::connect(path).await?;
266
267            return Ok(with_socket.with_socket(stream).await);
268        }
269
270        cfg_if! {
271            if #[cfg(feature = "_rt-async-io")] {
272                use async_io::Async;
273                use std::os::unix::net::UnixStream;
274
275                let stream = Async::<UnixStream>::connect(path).await?;
276
277                Ok(with_socket.with_socket(stream).await)
278            } else {
279                crate::rt::missing_rt((path, with_socket))
280            }
281        }
282    }
283
284    #[cfg(not(unix))]
285    {
286        drop((path, with_socket));
287
288        Err(io::Error::new(
289            io::ErrorKind::Unsupported,
290            "Unix domain sockets are not supported on this platform",
291        )
292        .into())
293    }
294}