1use std::net::SocketAddr;
4#[cfg(not(target_os = "windows"))]
5use std::os::fd::AsFd;
6#[cfg(target_os = "windows")]
7use std::os::windows::io::AsSocket;
8use std::time::Duration;
9
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11use tokio::net::tcp::{ReadHalf, WriteHalf};
12use tokio::net::{TcpListener, TcpStream};
13
14use crate::addr::{each_addr, ToSocketAddrs};
15use crate::udp::{UdpListener, UdpStream, UdpStreamReadHalf, UdpStreamWriteHalf};
16
/// File-local `Result` alias whose error type defaults to [`std::io::Error`].
///
/// A different error type can still be named explicitly as the second parameter.
type Result<T, E = std::io::Error> = std::result::Result<T, E>;
18
19pub trait StreamSplit {
21 type ReaderRef<'a>: AsyncReadExt + Send + Unpin
23 where
24 Self: 'a;
25 type WriterRef<'a>: AsyncWriteExt + Send + Unpin
27 where
28 Self: 'a;
29 type ReaderOwned: AsyncReadExt + Send + Unpin + 'static;
31 type WriterOwned: AsyncWriteExt + Send + Unpin + 'static;
33
34 fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>);
36
37 fn into_split(self) -> (Self::ReaderOwned, Self::WriterOwned);
39}
40
41pub trait NetworkStream:
43 StreamSplit + AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static
44{
45}
46
// Generates a public newtype `$struct_name` around `$inner_ty` together with
// `AsyncRead` / `AsyncWrite` impls that forward every call to the wrapped
// stream.
//
// `$inner_ty` must itself implement `AsyncRead + AsyncWrite + Unpin`: the
// generated code re-pins the inner value with `Pin::new`, which is only
// available for `Unpin` types.
macro_rules! gen_stream_impl {
    ($struct_name:ident, $inner_ty:ty) => {
        /// Newtype wrapper that forwards all asynchronous I/O to the stream
        /// it contains.
        pub struct $struct_name($inner_ty);

        impl $struct_name {
            /// Wraps an existing stream.
            pub fn new(stream: $inner_ty) -> Self {
                Self(stream)
            }
        }

        impl AsyncRead for $struct_name {
            fn poll_read(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_read(cx, buf)
            }
        }

        impl AsyncWrite for $struct_name {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &[u8],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_flush(cx)
            }

            fn poll_shutdown(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
            }

            // The two vectored methods have default implementations in
            // `AsyncWrite`. Without explicit forwarding the wrapper would use
            // those defaults and hide whatever vectored-write support the
            // inner stream provides.
            fn poll_write_vectored(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                bufs: &[std::io::IoSlice<'_>],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
            }

            fn is_write_vectored(&self) -> bool {
                AsyncWrite::is_write_vectored(&self.0)
            }
        }
    };
}
94
95gen_stream_impl!(TcpStreamImpl, TcpStream);
96gen_stream_impl!(UdpStreamImpl, UdpStream);
97
98impl StreamSplit for TcpStreamImpl {
99 type ReaderOwned = tokio::net::tcp::OwnedReadHalf;
100 type ReaderRef<'a>
101 = ReadHalf<'a>
102 where
103 Self: 'a;
104 type WriterOwned = tokio::net::tcp::OwnedWriteHalf;
105 type WriterRef<'a>
106 = WriteHalf<'a>
107 where
108 Self: 'a;
109
110 fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
111 self.0.split()
112 }
113
114 fn into_split(self) -> (Self::ReaderOwned, Self::WriterOwned) {
115 self.0.into_split()
116 }
117}
118
119impl StreamSplit for UdpStreamImpl {
120 type ReaderOwned = crate::udp::UdpStreamOwnedReadHalf;
121 type ReaderRef<'a> = UdpStreamReadHalf;
122 type WriterOwned = crate::udp::UdpStreamOwnedWriteHalf;
123 type WriterRef<'a>
124 = UdpStreamWriteHalf<'a>
125 where
126 Self: 'a;
127
128 fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
129 self.0.split()
130 }
131
132 fn into_split(self) -> (Self::ReaderOwned, Self::WriterOwned) {
133 self.0.into_split()
134 }
135}
136
137impl NetworkStream for TcpStreamImpl {}
138impl NetworkStream for UdpStreamImpl {}
139
140pub trait StreamProvider {
142 type Item: NetworkStream;
144
145 fn from_addr<A: ToSocketAddrs + Send>(
147 addr: A,
148 ) -> impl std::future::Future<Output = Result<Self::Item>> + Send;
149}
150
151pub struct TcpStreamProvider;
153
154impl StreamProvider for TcpStreamProvider {
155 type Item = TcpStreamImpl;
156
157 async fn from_addr<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
158 Ok(TcpStreamImpl(each_addr(addr, TcpStream::connect).await?))
159 }
160}
161
162pub struct UdpStreamProvider;
164
165impl StreamProvider for UdpStreamProvider {
166 type Item = UdpStreamImpl;
167
168 async fn from_addr<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
169 Ok(UdpStreamImpl(UdpStream::connect(addr).await?))
170 }
171}
172
173pub trait ListenerProvider {
175 type Listener: StreamAccept + 'static;
177
178 fn bind<A: ToSocketAddrs + Send>(
180 addr: A,
181 ) -> impl std::future::Future<Output = Result<Self::Listener>> + Send;
182}
183
184pub trait StreamAccept {
186 type Item: NetworkStream;
188
189 fn accept(&self) -> impl std::future::Future<Output = Result<(Self::Item, SocketAddr)>> + Send;
191}
192
193pub struct TcpListenerProvider;
195
196pub struct TcpListenerImpl(TcpListener);
198
199impl StreamAccept for TcpListenerImpl {
200 type Item = TcpStreamImpl;
201
202 async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
203 let (stream, addr) = self.0.accept().await?;
204 Ok((TcpStreamImpl::new(stream), addr))
205 }
206}
207
208impl ListenerProvider for TcpListenerProvider {
209 type Listener = TcpListenerImpl;
210
211 async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
212 Ok(TcpListenerImpl(each_addr(addr, TcpListener::bind).await?))
213 }
214}
215
216pub struct UdpListenerProvider;
218
219pub struct UdpListenerImpl(UdpListener);
221
222impl StreamAccept for UdpListenerImpl {
223 type Item = UdpStreamImpl;
224
225 async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
226 let (stream, addr) = self.0.accept().await?;
227 Ok((UdpStreamImpl::new(stream), addr))
228 }
229}
230
231impl ListenerProvider for UdpListenerProvider {
232 type Listener = UdpListenerImpl;
233
234 async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
235 Ok(UdpListenerImpl(UdpListener::bind(addr).await?))
236 }
237}
238
/// Duration passed to `socket2::TcpKeepalive::with_time` by
/// `set_tcp_keep_alive`.
const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(20);

/// Duration passed to `socket2::TcpKeepalive::with_interval` by
/// `set_tcp_keep_alive`.
const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(20);
243#[cfg(not(target_os = "windows"))]
245pub fn set_tcp_keep_alive<S: AsFd>(stream: &S) -> std::result::Result<(), std::io::Error> {
246 let sock_ref = socket2::SockRef::from(stream);
247 let mut ka = socket2::TcpKeepalive::new();
248 ka = ka.with_time(TCP_KEEPALIVE_TIME);
249 ka = ka.with_interval(TCP_KEEPALIVE_INTERVAL);
250 sock_ref.set_tcp_keepalive(&ka)
251}
252
/// Applies the file's TCP keepalive settings (`TCP_KEEPALIVE_TIME` and
/// `TCP_KEEPALIVE_INTERVAL`) to the socket behind `stream`.
///
/// This is the variant compiled for Windows targets.
///
/// # Errors
/// Returns the error reported by `socket2` when the option cannot be set.
#[cfg(target_os = "windows")]
pub fn set_tcp_keep_alive<S: AsSocket>(stream: &S) -> std::result::Result<(), std::io::Error> {
    let keepalive = socket2::TcpKeepalive::new()
        .with_time(TCP_KEEPALIVE_TIME)
        .with_interval(TCP_KEEPALIVE_INTERVAL);
    socket2::SockRef::from(stream).set_tcp_keepalive(&keepalive)
}
262
263#[cfg(not(target_os = "windows"))]
265pub fn set_tcp_nodelay<S: AsFd>(stream: &S) -> std::result::Result<(), std::io::Error> {
266 let sock_ref = socket2::SockRef::from(stream);
267 sock_ref.set_tcp_nodelay(true)
268}
269
/// Turns on the `TCP_NODELAY` option for the socket behind `stream`.
///
/// This is the variant compiled for Windows targets.
///
/// # Errors
/// Returns the error reported by `socket2` when the option cannot be set.
#[cfg(target_os = "windows")]
pub fn set_tcp_nodelay<S: AsSocket>(stream: &S) -> std::result::Result<(), std::io::Error> {
    socket2::SockRef::from(stream).set_tcp_nodelay(true)
}
276
277pub async fn got_one_socket_addr<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
279 let mut iter = addr.to_socket_addrs().await?;
280 iter.next().ok_or_else(|| {
281 std::io::Error::new(
282 std::io::ErrorKind::InvalidInput,
283 "could not resolve to any addresses",
284 )
285 })
286}