1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
//! A byte stream that is either a TCP connection or, on Unix platforms, a Unix
//! domain socket connection, exposed through the active async runtime's
//! `AsyncRead` / `AsyncWrite` traits (re-exported by `sqlx_rt`).

// NOTE(review): presumably not every item here is used under every combination
// of enabled runtime features and target platforms, hence the blanket allow.
#![allow(dead_code)]

use std::io;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};

use sqlx_rt::{AsyncRead, AsyncWrite, TcpStream};

/// A connected, bidirectional byte stream.
///
/// All I/O trait methods delegate directly to the wrapped runtime stream.
#[derive(Debug)]
pub enum Socket {
    /// A TCP connection.
    Tcp(TcpStream),

    /// A Unix domain socket connection (only available on Unix platforms).
    #[cfg(unix)]
    Unix(sqlx_rt::UnixStream),
}

impl Socket {
    /// Opens a TCP connection to `host` on `port`.
    ///
    /// `host` may be a hostname or an IP address literal. An IPv6 literal may
    /// also be given in bracketed form (e.g. `[::1]`, as URL parsers render
    /// IPv6 hosts). The brackets are removed before resolution because they
    /// are URL syntax, not part of the address.
    ///
    /// # Errors
    ///
    /// Returns any I/O error produced while resolving `host` or connecting.
    pub async fn connect_tcp(host: &str, port: u16) -> io::Result<Self> {
        TcpStream::connect((strip_ipv6_brackets(host), port))
            .await
            .map(Socket::Tcp)
    }

    /// Connects to the Unix domain socket at `path`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error produced while connecting.
    #[cfg(unix)]
    pub async fn connect_uds(path: impl AsRef<Path>) -> io::Result<Self> {
        sqlx_rt::UnixStream::connect(path.as_ref())
            .await
            .map(Socket::Unix)
    }

    /// Fallback for non-Unix platforms, where Unix domain sockets are unavailable.
    ///
    /// # Errors
    ///
    /// Always returns an error of kind [`io::ErrorKind::Other`].
    #[cfg(not(unix))]
    pub async fn connect_uds(_: impl AsRef<Path>) -> io::Result<Self> {
        Err(io::Error::new(
            io::ErrorKind::Other,
            "Unix domain sockets are not supported outside Unix platforms.",
        ))
    }

    /// Shuts down the connection.
    ///
    /// Under the async-std runtime this synchronously shuts down both halves
    /// of the stream (`Shutdown::Both`). Under the tokio or actix runtimes it
    /// awaits the runtime's `AsyncWriteExt::shutdown` on the inner stream.
    ///
    /// Exactly one of the two `cfg` branches below must be compiled in. With
    /// neither or both enabled, the body does not type-check.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by the underlying shutdown call.
    pub async fn shutdown(&mut self) -> io::Result<()> {
        #[cfg(feature = "_rt-async-std")]
        {
            use std::net::Shutdown;

            match self {
                Socket::Tcp(s) => s.shutdown(Shutdown::Both),

                #[cfg(unix)]
                Socket::Unix(s) => s.shutdown(Shutdown::Both),
            }
        }

        #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
        {
            use sqlx_rt::AsyncWriteExt;

            match self {
                Socket::Tcp(s) => s.shutdown().await,

                #[cfg(unix)]
                Socket::Unix(s) => s.shutdown().await,
            }
        }
    }
}

/// Returns `host` without its surrounding `[` and `]` when it is written in
/// bracketed form (e.g. `[::1]` becomes `::1`). Otherwise returns `host`
/// unchanged.
///
/// A bracketed string can never resolve as a hostname, so stripping the
/// brackets only affects inputs that could not previously connect.
fn strip_ipv6_brackets(host: &str) -> &str {
    if host.len() >= 2 && host.starts_with('[') && host.ends_with(']') {
        // Both delimiters are single-byte ASCII, so these byte offsets are
        // valid `char` boundaries.
        &host[1..host.len() - 1]
    } else {
        host
    }
}

// NOTE(review): `super::PollReadBuf` / `super::PollReadOut` are defined in the
// parent module, presumably to paper over differing `poll_read` signatures
// between runtimes. Confirm there.
impl AsyncRead for Socket {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut super::PollReadBuf<'_>,
    ) -> Poll<io::Result<super::PollReadOut>> {
        // `Pin::new` requires the inner stream to be `Unpin`, so re-pinning a
        // plain `&mut` borrow of it is sound.
        match &mut *self {
            Socket::Tcp(s) => Pin::new(s).poll_read(cx, buf),

            #[cfg(unix)]
            Socket::Unix(s) => Pin::new(s).poll_read(cx, buf),
        }
    }
}

impl AsyncWrite for Socket {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match &mut *self {
            Socket::Tcp(s) => Pin::new(s).poll_write(cx, buf),

            #[cfg(unix)]
            Socket::Unix(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            Socket::Tcp(s) => Pin::new(s).poll_flush(cx),

            #[cfg(unix)]
            Socket::Unix(s) => Pin::new(s).poll_flush(cx),
        }
    }

    // The tokio-style `AsyncWrite` trait names its close hook `poll_shutdown`...
    #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            Socket::Tcp(s) => Pin::new(s).poll_shutdown(cx),

            #[cfg(unix)]
            Socket::Unix(s) => Pin::new(s).poll_shutdown(cx),
        }
    }

    // ...while the futures-style trait used with async-std names it `poll_close`.
    #[cfg(feature = "_rt-async-std")]
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            Socket::Tcp(s) => Pin::new(s).poll_close(cx),

            #[cfg(unix)]
            Socket::Unix(s) => Pin::new(s).poll_close(cx),
        }
    }
}