tokio_unix_tcp/
stream.rs

/*
 * Copyright (c) 2023, networkException <git@nwex.de>
 *
 * SPDX-License-Identifier: BSD-2-Clause OR MIT
 */
6
7#[cfg(unix)]
8use tokio::net::UnixStream;
9
10use std::io;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15use tokio::net::TcpStream;
16
17use crate::NamedSocketAddr;
18use crate::SocketAddr;
19
20#[derive(Debug)]
21pub enum Stream {
22    Tcp(TcpStream),
23    #[cfg(unix)]
24    Unix(UnixStream),
25}
26
27impl From<TcpStream> for Stream {
28    fn from(tcp_stream: TcpStream) -> Self {
29        Stream::Tcp(tcp_stream)
30    }
31}
32
33#[cfg(unix)]
34impl From<UnixStream> for Stream {
35    fn from(unix_stream: UnixStream) -> Self {
36        Stream::Unix(unix_stream)
37    }
38}
39
40impl Stream {
41    pub async fn connect(named_socket_addr: &NamedSocketAddr) -> io::Result<Self> {
42        match named_socket_addr {
43            NamedSocketAddr::Inet(inet_socket_addr) => TcpStream::connect(inet_socket_addr).await.map(Stream::Tcp),
44            #[cfg(unix)]
45            NamedSocketAddr::Unix(path) => UnixStream::connect(path).await.map(Stream::Unix)
46        }
47    }
48
49    pub fn local_addr(&self) -> io::Result<SocketAddr> {
50        match self {
51            Stream::Tcp(tcp_stream) => tcp_stream.local_addr().map(SocketAddr::Inet),
52            #[cfg(unix)]
53            Stream::Unix(unix_stream) => Ok(SocketAddr::Unix(unix_stream.local_addr()?.into())),
54        }
55    }
56
57    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
58        match self {
59            Stream::Tcp(tcp_stream) => tcp_stream.peer_addr().map(SocketAddr::Inet),
60            #[cfg(unix)]
61            Stream::Unix(unix_stream) => Ok(SocketAddr::Unix(unix_stream.local_addr()?.into())),
62        }
63    }
64}
65
66impl AsyncRead for Stream {
67    fn poll_read(
68        self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70        buf: &mut ReadBuf<'_>,
71    ) -> Poll<io::Result<()>> {
72        match Pin::into_inner(self) {
73            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_read(cx, buf),
74            #[cfg(unix)]
75            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_read(cx, buf),
76        }
77    }
78}
79
80impl AsyncWrite for Stream {
81    fn poll_write(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84        buf: &[u8],
85    ) -> Poll<io::Result<usize>> {
86        match Pin::into_inner(self) {
87            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_write(cx, buf),
88            #[cfg(unix)]
89            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_write(cx, buf),
90        }
91    }
92
93    fn poll_write_vectored(
94        self: Pin<&mut Self>,
95        cx: &mut Context<'_>,
96        bufs: &[io::IoSlice<'_>],
97    ) -> Poll<io::Result<usize>> {
98        match Pin::into_inner(self) {
99            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_write_vectored(cx, bufs),
100            #[cfg(unix)]
101            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_write_vectored(cx, bufs),
102        }
103    }
104
105    fn is_write_vectored(&self) -> bool {
106        match self {
107            Stream::Tcp(tcp_stream) => tcp_stream.is_write_vectored(),
108            #[cfg(unix)]
109            Stream::Unix(unix_stream) => unix_stream.is_write_vectored(),
110        }
111    }
112
113    #[inline]
114    fn poll_flush(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<io::Result<()>> {
115        match Pin::into_inner(self) {
116            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_flush(context),
117            #[cfg(unix)]
118            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_flush(context),
119        }
120    }
121
122    fn poll_shutdown(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<io::Result<()>> {
123        match Pin::into_inner(self) {
124            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_shutdown(context),
125            #[cfg(unix)]
126            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_shutdown(context),
127        }
128    }
129}