sqlx_core/net/socket/
mod.rs1use std::future::Future;
2use std::io;
3use std::path::Path;
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6
7pub use buffered::{BufferedSocket, WriteBuffer};
8use bytes::BufMut;
9use cfg_if::cfg_if;
10
11use crate::io::ReadBuf;
12
13mod buffered;
14
15pub trait Socket: Send + Sync + Unpin + 'static {
16 fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize>;
17
18 fn try_write(&mut self, buf: &[u8]) -> io::Result<usize>;
19
20 fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
21
22 fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
23
24 fn poll_flush(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
25 Poll::Ready(Ok(()))
27 }
28
29 fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
30
31 fn read<'a, B: ReadBuf>(&'a mut self, buf: &'a mut B) -> Read<'a, Self, B>
32 where
33 Self: Sized,
34 {
35 Read { socket: self, buf }
36 }
37
38 fn write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, Self>
39 where
40 Self: Sized,
41 {
42 Write { socket: self, buf }
43 }
44
45 fn flush(&mut self) -> Flush<'_, Self>
46 where
47 Self: Sized,
48 {
49 Flush { socket: self }
50 }
51
52 fn shutdown(&mut self) -> Shutdown<'_, Self>
53 where
54 Self: Sized,
55 {
56 Shutdown { socket: self }
57 }
58}
59
60pub struct Read<'a, S: ?Sized, B> {
61 socket: &'a mut S,
62 buf: &'a mut B,
63}
64
65impl<S: ?Sized, B> Future for Read<'_, S, B>
66where
67 S: Socket,
68 B: ReadBuf,
69{
70 type Output = io::Result<usize>;
71
72 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
73 let this = &mut *self;
74
75 while this.buf.has_remaining_mut() {
76 match this.socket.try_read(&mut *this.buf) {
77 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
78 ready!(this.socket.poll_read_ready(cx))?;
79 }
80 ready => return Poll::Ready(ready),
81 }
82 }
83
84 Poll::Ready(Ok(0))
85 }
86}
87
88pub struct Write<'a, S: ?Sized> {
89 socket: &'a mut S,
90 buf: &'a [u8],
91}
92
93impl<S: ?Sized> Future for Write<'_, S>
94where
95 S: Socket,
96{
97 type Output = io::Result<usize>;
98
99 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100 let this = &mut *self;
101
102 while !this.buf.is_empty() {
103 match this.socket.try_write(this.buf) {
104 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
105 ready!(this.socket.poll_write_ready(cx))?;
106 }
107 ready => return Poll::Ready(ready),
108 }
109 }
110
111 Poll::Ready(Ok(0))
112 }
113}
114
115pub struct Flush<'a, S: ?Sized> {
116 socket: &'a mut S,
117}
118
119impl<S: Socket + ?Sized> Future for Flush<'_, S> {
120 type Output = io::Result<()>;
121
122 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123 self.socket.poll_flush(cx)
124 }
125}
126
127pub struct Shutdown<'a, S: ?Sized> {
128 socket: &'a mut S,
129}
130
131impl<S: ?Sized> Future for Shutdown<'_, S>
132where
133 S: Socket,
134{
135 type Output = io::Result<()>;
136
137 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138 self.socket.poll_shutdown(cx)
139 }
140}
141
142pub trait WithSocket {
143 type Output;
144
145 fn with_socket<S: Socket>(self, socket: S) -> impl Future<Output = Self::Output> + Send;
146}
147
148pub struct SocketIntoBox;
149
150impl WithSocket for SocketIntoBox {
151 type Output = Box<dyn Socket>;
152
153 async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
154 Box::new(socket)
155 }
156}
157
158impl<S: Socket + ?Sized> Socket for Box<S> {
159 fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize> {
160 (**self).try_read(buf)
161 }
162
163 fn try_write(&mut self, buf: &[u8]) -> io::Result<usize> {
164 (**self).try_write(buf)
165 }
166
167 fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
168 (**self).poll_read_ready(cx)
169 }
170
171 fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
172 (**self).poll_write_ready(cx)
173 }
174
175 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
176 (**self).poll_flush(cx)
177 }
178
179 fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180 (**self).poll_shutdown(cx)
181 }
182}
183
184pub async fn connect_tcp<Ws: WithSocket>(
185 host: &str,
186 port: u16,
187 with_socket: Ws,
188) -> crate::Result<Ws::Output> {
189 #[cfg(feature = "_rt-tokio")]
190 if crate::rt::rt_tokio::available() {
191 return Ok(with_socket
192 .with_socket(tokio::net::TcpStream::connect((host, port)).await?)
193 .await);
194 }
195
196 cfg_if! {
197 if #[cfg(feature = "_rt-async-io")] {
198 Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)
199 } else {
200 crate::rt::missing_rt((host, port, with_socket))
201 }
202 }
203}
204
/// Connect to `host:port` with `async-io`, trying each resolved address in order.
///
/// `host` may be an IP literal, optionally wrapped in `[...]` as IPv6 literals are in URLs,
/// or a hostname.
///
/// # Errors
/// Returns the error from the last failed connection attempt. If the hostname resolved to
/// no addresses at all, it returns an `AddrNotAvailable` error instead.
#[cfg(feature = "_rt-async-io")]
async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> {
    use async_io::Async;
    use std::net::{IpAddr, TcpStream, ToSocketAddrs};

    // Strip the brackets URL syntax puts around IPv6 literals (`[::1]` -> `::1`).
    let host = host.trim_matches(&['[', ']'][..]);

    // Fast path: an IP literal needs no name resolution.
    if let Ok(ip) = host.parse::<IpAddr>() {
        return Ok(Async::<TcpStream>::connect((ip, port)).await?);
    }

    // Resolution may perform a blocking DNS lookup, so it goes through
    // `crate::rt::spawn_blocking` rather than running inline.
    let owned_host = host.to_owned();
    let candidates = crate::rt::spawn_blocking(move || {
        ToSocketAddrs::to_socket_addrs(&(owned_host.as_str(), port))
    })
    .await?;

    let mut last_failure: Option<io::Error> = None;

    for candidate in candidates {
        match Async::<TcpStream>::connect(candidate).await {
            Ok(stream) => return Ok(stream),
            Err(err) => last_failure = Some(err),
        }
    }

    let err = last_failure.unwrap_or_else(|| {
        io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            "Hostname did not resolve to any addresses",
        )
    });

    Err(err.into())
}
251
252pub async fn connect_uds<P: AsRef<Path>, Ws: WithSocket>(
256 path: P,
257 with_socket: Ws,
258) -> crate::Result<Ws::Output> {
259 #[cfg(unix)]
260 {
261 #[cfg(feature = "_rt-tokio")]
262 if crate::rt::rt_tokio::available() {
263 use tokio::net::UnixStream;
264
265 let stream = UnixStream::connect(path).await?;
266
267 return Ok(with_socket.with_socket(stream).await);
268 }
269
270 cfg_if! {
271 if #[cfg(feature = "_rt-async-io")] {
272 use async_io::Async;
273 use std::os::unix::net::UnixStream;
274
275 let stream = Async::<UnixStream>::connect(path).await?;
276
277 Ok(with_socket.with_socket(stream).await)
278 } else {
279 crate::rt::missing_rt((path, with_socket))
280 }
281 }
282 }
283
284 #[cfg(not(unix))]
285 {
286 drop((path, with_socket));
287
288 Err(io::Error::new(
289 io::ErrorKind::Unsupported,
290 "Unix domain sockets are not supported on this platform",
291 )
292 .into())
293 }
294}