1#[cfg(unix)]
8use tokio::net::UnixStream;
9
10use std::io;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15use tokio::net::TcpStream;
16
17use crate::NamedSocketAddr;
18use crate::SocketAddr;
19
20#[derive(Debug)]
21pub enum Stream {
22 Tcp(TcpStream),
23 #[cfg(unix)]
24 Unix(UnixStream),
25}
26
27impl From<TcpStream> for Stream {
28 fn from(tcp_stream: TcpStream) -> Self {
29 Stream::Tcp(tcp_stream)
30 }
31}
32
33#[cfg(unix)]
34impl From<UnixStream> for Stream {
35 fn from(unix_stream: UnixStream) -> Self {
36 Stream::Unix(unix_stream)
37 }
38}
39
40impl Stream {
41 pub async fn connect(named_socket_addr: &NamedSocketAddr) -> io::Result<Self> {
42 match named_socket_addr {
43 NamedSocketAddr::Inet(inet_socket_addr) => TcpStream::connect(inet_socket_addr).await.map(Stream::Tcp),
44 #[cfg(unix)]
45 NamedSocketAddr::Unix(path) => UnixStream::connect(path).await.map(Stream::Unix)
46 }
47 }
48
49 pub fn local_addr(&self) -> io::Result<SocketAddr> {
50 match self {
51 Stream::Tcp(tcp_stream) => tcp_stream.local_addr().map(SocketAddr::Inet),
52 #[cfg(unix)]
53 Stream::Unix(unix_stream) => Ok(SocketAddr::Unix(unix_stream.local_addr()?.into())),
54 }
55 }
56
57 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
58 match self {
59 Stream::Tcp(tcp_stream) => tcp_stream.peer_addr().map(SocketAddr::Inet),
60 #[cfg(unix)]
61 Stream::Unix(unix_stream) => Ok(SocketAddr::Unix(unix_stream.local_addr()?.into())),
62 }
63 }
64}
65
66impl AsyncRead for Stream {
67 fn poll_read(
68 self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 buf: &mut ReadBuf<'_>,
71 ) -> Poll<io::Result<()>> {
72 match Pin::into_inner(self) {
73 Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_read(cx, buf),
74 #[cfg(unix)]
75 Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_read(cx, buf),
76 }
77 }
78}
79
80impl AsyncWrite for Stream {
81 fn poll_write(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 buf: &[u8],
85 ) -> Poll<io::Result<usize>> {
86 match Pin::into_inner(self) {
87 Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_write(cx, buf),
88 #[cfg(unix)]
89 Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_write(cx, buf),
90 }
91 }
92
93 fn poll_write_vectored(
94 self: Pin<&mut Self>,
95 cx: &mut Context<'_>,
96 bufs: &[io::IoSlice<'_>],
97 ) -> Poll<io::Result<usize>> {
98 match Pin::into_inner(self) {
99 Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_write_vectored(cx, bufs),
100 #[cfg(unix)]
101 Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_write_vectored(cx, bufs),
102 }
103 }
104
105 fn is_write_vectored(&self) -> bool {
106 match self {
107 Stream::Tcp(tcp_stream) => tcp_stream.is_write_vectored(),
108 #[cfg(unix)]
109 Stream::Unix(unix_stream) => unix_stream.is_write_vectored(),
110 }
111 }
112
113 #[inline]
114 fn poll_flush(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<io::Result<()>> {
115 match Pin::into_inner(self) {
116 Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_flush(context),
117 #[cfg(unix)]
118 Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_flush(context),
119 }
120 }
121
122 fn poll_shutdown(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<io::Result<()>> {
123 match Pin::into_inner(self) {
124 Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_shutdown(context),
125 #[cfg(unix)]
126 Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_shutdown(context),
127 }
128 }
129}