Skip to main content

uni_stream/
stream.rs

//! Stream abstractions and their TCP and UDP implementations.
2
3use std::net::SocketAddr;
4#[cfg(not(target_os = "windows"))]
5use std::os::fd::AsFd;
6#[cfg(target_os = "windows")]
7use std::os::windows::io::AsSocket;
8use std::time::Duration;
9
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11use tokio::net::tcp::{ReadHalf, WriteHalf};
12use tokio::net::{TcpListener, TcpStream};
13
14use crate::addr::{each_addr, ToSocketAddrs};
15use crate::udp::{UdpListener, UdpStream, UdpStreamReadHalf, UdpStreamWriteHalf};
16
/// Module-local `Result` alias whose error type defaults to [`std::io::Error`].
type Result<T, E = std::io::Error> = core::result::Result<T, E>;
18
19/// A stream that can be split into read and write halves.
20pub trait StreamSplit {
21    /// Reader half type.
22    type ReaderRef<'a>: AsyncReadExt + Send + Unpin
23    where
24        Self: 'a;
25    /// Writer half type.
26    type WriterRef<'a>: AsyncWriteExt + Send + Unpin
27    where
28        Self: 'a;
29    /// Owned reader half type.
30    type ReaderOwned: AsyncReadExt + Send + Unpin + 'static;
31    /// Owned writer half type.
32    type WriterOwned: AsyncWriteExt + Send + Unpin + 'static;
33
34    /// Split into reader and writer halves.
35    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>);
36
37    /// Split into owned reader and writer halves.
38    fn into_split(self) -> (Self::ReaderOwned, Self::WriterOwned);
39}
40
41/// Marker trait for streams used in the system.
42pub trait NetworkStream:
43    StreamSplit + AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static
44{
45}
46
47macro_rules! gen_stream_impl {
48    ($struct_name:ident, $inner_ty:ty) => {
49        /// Wrapper type used to implement stream traits.
50        pub struct $struct_name($inner_ty);
51
52        impl $struct_name {
53            /// Create a new wrapper.
54            pub fn new(stream: $inner_ty) -> Self {
55                Self(stream)
56            }
57        }
58
59        impl AsyncRead for $struct_name {
60            fn poll_read(
61                mut self: std::pin::Pin<&mut Self>,
62                cx: &mut std::task::Context<'_>,
63                buf: &mut tokio::io::ReadBuf<'_>,
64            ) -> std::task::Poll<std::io::Result<()>> {
65                std::pin::Pin::new(&mut self.0).poll_read(cx, buf)
66            }
67        }
68
69        impl AsyncWrite for $struct_name {
70            fn poll_write(
71                mut self: std::pin::Pin<&mut Self>,
72                cx: &mut std::task::Context<'_>,
73                buf: &[u8],
74            ) -> std::task::Poll<std::prelude::v1::Result<usize, std::io::Error>> {
75                std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
76            }
77
78            fn poll_flush(
79                mut self: std::pin::Pin<&mut Self>,
80                cx: &mut std::task::Context<'_>,
81            ) -> std::task::Poll<std::prelude::v1::Result<(), std::io::Error>> {
82                std::pin::Pin::new(&mut self.0).poll_flush(cx)
83            }
84
85            fn poll_shutdown(
86                mut self: std::pin::Pin<&mut Self>,
87                cx: &mut std::task::Context<'_>,
88            ) -> std::task::Poll<std::prelude::v1::Result<(), std::io::Error>> {
89                std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
90            }
91        }
92    };
93}
94
95gen_stream_impl!(TcpStreamImpl, TcpStream);
96gen_stream_impl!(UdpStreamImpl, UdpStream);
97
98impl StreamSplit for TcpStreamImpl {
99    type ReaderOwned = tokio::net::tcp::OwnedReadHalf;
100    type ReaderRef<'a>
101        = ReadHalf<'a>
102    where
103        Self: 'a;
104    type WriterOwned = tokio::net::tcp::OwnedWriteHalf;
105    type WriterRef<'a>
106        = WriteHalf<'a>
107    where
108        Self: 'a;
109
110    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
111        self.0.split()
112    }
113
114    fn into_split(self) -> (Self::ReaderOwned, Self::WriterOwned) {
115        self.0.into_split()
116    }
117}
118
119impl StreamSplit for UdpStreamImpl {
120    type ReaderOwned = crate::udp::UdpStreamOwnedReadHalf;
121    type ReaderRef<'a> = UdpStreamReadHalf;
122    type WriterOwned = crate::udp::UdpStreamOwnedWriteHalf;
123    type WriterRef<'a>
124        = UdpStreamWriteHalf<'a>
125    where
126        Self: 'a;
127
128    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
129        self.0.split()
130    }
131
132    fn into_split(self) -> (Self::ReaderOwned, Self::WriterOwned) {
133        self.0.into_split()
134    }
135}
136
137impl NetworkStream for TcpStreamImpl {}
138impl NetworkStream for UdpStreamImpl {}
139
140/// Provides an abstraction for connect.
141pub trait StreamProvider {
142    /// Stream obtained after connect.
143    type Item: NetworkStream;
144
145    /// Create a stream from a socket address or hostname.
146    fn from_addr<A: ToSocketAddrs + Send>(
147        addr: A,
148    ) -> impl std::future::Future<Output = Result<Self::Item>> + Send;
149}
150
/// Stateless [`StreamProvider`] that opens TCP connections.
///
/// It carries no data, so it is cheap to copy and construct via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TcpStreamProvider;
153
154impl StreamProvider for TcpStreamProvider {
155    type Item = TcpStreamImpl;
156
157    async fn from_addr<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
158        Ok(TcpStreamImpl(each_addr(addr, TcpStream::connect).await?))
159    }
160}
161
/// Stateless [`StreamProvider`] that opens UDP-backed streams.
///
/// It carries no data, so it is cheap to copy and construct via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct UdpStreamProvider;
164
165impl StreamProvider for UdpStreamProvider {
166    type Item = UdpStreamImpl;
167
168    async fn from_addr<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
169        Ok(UdpStreamImpl(UdpStream::connect(addr).await?))
170    }
171}
172
173/// Provides an abstraction for bind.
174pub trait ListenerProvider {
175    /// Listener obtained after bind.
176    type Listener: StreamAccept + 'static;
177
178    /// Bind a listener from address/hostname.
179    fn bind<A: ToSocketAddrs + Send>(
180        addr: A,
181    ) -> impl std::future::Future<Output = Result<Self::Listener>> + Send;
182}
183
184/// Abstractions for listener-provided operations.
185pub trait StreamAccept {
186    /// Stream obtained after accept.
187    type Item: NetworkStream;
188
189    /// Listener waits to get new Stream.
190    fn accept(&self) -> impl std::future::Future<Output = Result<(Self::Item, SocketAddr)>> + Send;
191}
192
/// Stateless [`ListenerProvider`] that binds TCP listeners.
///
/// It carries no data, so it is cheap to copy and construct via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TcpListenerProvider;
195
196/// TCP listener wrapper.
197pub struct TcpListenerImpl(TcpListener);
198
199impl StreamAccept for TcpListenerImpl {
200    type Item = TcpStreamImpl;
201
202    async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
203        let (stream, addr) = self.0.accept().await?;
204        Ok((TcpStreamImpl::new(stream), addr))
205    }
206}
207
208impl ListenerProvider for TcpListenerProvider {
209    type Listener = TcpListenerImpl;
210
211    async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
212        Ok(TcpListenerImpl(each_addr(addr, TcpListener::bind).await?))
213    }
214}
215
/// Stateless [`ListenerProvider`] that binds UDP-backed listeners.
///
/// It carries no data, so it is cheap to copy and construct via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct UdpListenerProvider;
218
219/// UDP listener wrapper.
220pub struct UdpListenerImpl(UdpListener);
221
222impl StreamAccept for UdpListenerImpl {
223    type Item = UdpStreamImpl;
224
225    async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
226        let (stream, addr) = self.0.accept().await?;
227        Ok((UdpStreamImpl::new(stream), addr))
228    }
229}
230
231impl ListenerProvider for UdpListenerProvider {
232    type Listener = UdpListenerImpl;
233
234    async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
235        Ok(UdpListenerImpl(UdpListener::bind(addr).await?))
236    }
237}
238
239/// How long it takes for TCP to start sending keepalive probe packets when no data is exchanged.
240const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(20);
241/// Time interval between two consecutive keepalive probe packets.
242const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(20);
243/// Enable TCP keepalive on a socket.
244#[cfg(not(target_os = "windows"))]
245pub fn set_tcp_keep_alive<S: AsFd>(stream: &S) -> std::result::Result<(), std::io::Error> {
246    let sock_ref = socket2::SockRef::from(stream);
247    let mut ka = socket2::TcpKeepalive::new();
248    ka = ka.with_time(TCP_KEEPALIVE_TIME);
249    ka = ka.with_interval(TCP_KEEPALIVE_INTERVAL);
250    sock_ref.set_tcp_keepalive(&ka)
251}
252
253/// Enable TCP keepalive on a socket.
254#[cfg(target_os = "windows")]
255pub fn set_tcp_keep_alive<S: AsSocket>(stream: &S) -> std::result::Result<(), std::io::Error> {
256    let sock_ref = socket2::SockRef::from(stream);
257    let mut ka = socket2::TcpKeepalive::new();
258    ka = ka.with_time(TCP_KEEPALIVE_TIME);
259    ka = ka.with_interval(TCP_KEEPALIVE_INTERVAL);
260    sock_ref.set_tcp_keepalive(&ka)
261}
262
263/// Disable Nagle's algorithm to reduce latency for small packets.
264#[cfg(not(target_os = "windows"))]
265pub fn set_tcp_nodelay<S: AsFd>(stream: &S) -> std::result::Result<(), std::io::Error> {
266    let sock_ref = socket2::SockRef::from(stream);
267    sock_ref.set_tcp_nodelay(true)
268}
269
270/// Disable Nagle's algorithm to reduce latency for small packets.
271#[cfg(target_os = "windows")]
272pub fn set_tcp_nodelay<S: AsSocket>(stream: &S) -> std::result::Result<(), std::io::Error> {
273    let sock_ref = socket2::SockRef::from(stream);
274    sock_ref.set_tcp_nodelay(true)
275}
276
277/// Resolve a single socket address from input.
278pub async fn got_one_socket_addr<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
279    let mut iter = addr.to_socket_addrs().await?;
280    iter.next().ok_or_else(|| {
281        std::io::Error::new(
282            std::io::ErrorKind::InvalidInput,
283            "could not resolve to any addresses",
284        )
285    })
286}