1use std::net::SocketAddr;
4
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use tokio::net::tcp::{ReadHalf, WriteHalf};
7use tokio::net::{TcpListener, TcpStream};
8
9use crate::udp::UdpListener;
10
11use super::addr::{each_addr, ToSocketAddrs};
12use super::udp::{UdpStream, UdpStreamReadHalf, UdpStreamWriteHalf};
13
/// Module-local result alias; the error type defaults to [`std::io::Error`].
type Result<T, E = std::io::Error> = core::result::Result<T, E>;
15
16pub trait NetworkStream: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static {
18 type ReaderRef<'a>: AsyncReadExt + Send + Unpin + Send
20 where
21 Self: 'a;
22 type WriterRef<'a>: AsyncWriteExt + Send + Unpin + Send
24 where
25 Self: 'a;
26
27 type InnerStream: AsyncReadExt + AsyncWriteExt + Unpin + Send;
29
30 fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>);
32
33 fn into_inner_stream(self) -> Self::InnerStream;
35
36 fn local_addr(&self) -> Result<SocketAddr>;
38
39 fn peer_addr(&self) -> Result<SocketAddr>;
41}
42
/// Generates a `pub` newtype wrapper around an async stream type and forwards
/// tokio's [`AsyncRead`] / [`AsyncWrite`] to the wrapped value.
///
/// * `$struct_name` – name of the generated tuple struct.
/// * `$inner_ty` – wrapped stream type. It must implement `AsyncRead + AsyncWrite`
///   and be `Unpin`, because every poll method re-pins it with `Pin::new`.
/// * `$doc_string` – rustdoc text attached to the generated struct.
macro_rules! gen_stream_impl {
    ($struct_name:ident, $inner_ty:ty, $doc_string:literal) => {
        #[doc = $doc_string]
        pub struct $struct_name($inner_ty);

        impl $struct_name {
            /// Wraps an existing stream value.
            pub fn new(stream: $inner_ty) -> Self {
                Self(stream)
            }
        }

        impl AsyncRead for $struct_name {
            fn poll_read(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_read(cx, buf)
            }
        }

        impl AsyncWrite for $struct_name {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &[u8],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
            }

            // Forwarded explicitly: the trait's default `poll_write_vectored`
            // writes only the first non-empty buffer, and the default
            // `is_write_vectored` returns `false`. Either default would hide
            // the inner stream's vectored-write support.
            fn poll_write_vectored(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                bufs: &[std::io::IoSlice<'_>],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
            }

            fn is_write_vectored(&self) -> bool {
                self.0.is_write_vectored()
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_flush(cx)
            }

            fn poll_shutdown(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
            }
        }
    };
}
90
91gen_stream_impl!(
92 TcpStreamImpl,
93 TcpStream,
94 "Implementing NetworkStream for TcpStream"
95);
96
97gen_stream_impl!(
98 UdpStreamImpl,
99 UdpStream,
100 "Implementing NetworkStream for UdpStream"
101);
102
103impl NetworkStream for TcpStreamImpl {
104 type ReaderRef<'a> = ReadHalf<'a>
105 where
106 Self: 'a;
107
108 type WriterRef<'a> = WriteHalf<'a>
109 where
110 Self: 'a;
111
112 type InnerStream = TcpStream;
113
114 fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
115 self.0.split()
116 }
117
118 fn into_inner_stream(self) -> Self::InnerStream {
119 self.0
120 }
121
122 fn local_addr(&self) -> Result<SocketAddr> {
123 self.0.local_addr()
124 }
125
126 fn peer_addr(&self) -> Result<SocketAddr> {
127 self.0.peer_addr()
128 }
129}
130
131impl NetworkStream for UdpStreamImpl {
132 type ReaderRef<'a> = UdpStreamReadHalf<'static>;
133
134 type WriterRef<'a> = UdpStreamWriteHalf<'a>
135 where
136 Self: 'a;
137
138 type InnerStream = UdpStream;
139
140 fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
141 self.0.split()
142 }
143
144 fn into_inner_stream(self) -> Self::InnerStream {
145 self.0
146 }
147
148 fn local_addr(&self) -> Result<SocketAddr> {
149 self.0.local_addr()
150 }
151
152 fn peer_addr(&self) -> Result<SocketAddr> {
153 self.0.peer_addr()
154 }
155}
156
157pub trait StreamProvider {
159 type Item: NetworkStream;
161
162 fn connect<A: ToSocketAddrs + Send>(
166 addr: A,
167 ) -> impl std::future::Future<Output = Result<Self::Item>> + Send;
168}
169
170pub struct TcpStreamProvider;
172
173impl StreamProvider for TcpStreamProvider {
174 type Item = TcpStreamImpl;
175
176 async fn connect<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
177 Ok(TcpStreamImpl(each_addr(addr, TcpStream::connect).await?))
178 }
179}
180
181pub struct UdpStreamProvider;
183
184impl StreamProvider for UdpStreamProvider {
185 type Item = UdpStreamImpl;
186
187 async fn connect<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
188 Ok(UdpStreamImpl(UdpStream::connect(addr).await?))
189 }
190}
191
192pub trait ListenerProvider {
194 type Listener: StreamAccept + 'static;
196
197 fn bind<A: ToSocketAddrs + Send>(
201 addr: A,
202 ) -> impl std::future::Future<Output = Result<Self::Listener>> + Send;
203}
204
205pub trait StreamAccept {
207 type Item: NetworkStream;
209
210 fn accept(&self) -> impl std::future::Future<Output = Result<(Self::Item, SocketAddr)>> + Send;
212}
213
214pub struct TcpListenerProvider;
216
217impl ListenerProvider for TcpListenerProvider {
218 type Listener = TcpListenerImpl;
219
220 async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
221 Ok(TcpListenerImpl(each_addr(addr, TcpListener::bind).await?))
222 }
223}
224
225pub struct TcpListenerImpl(TcpListener);
227
228impl StreamAccept for TcpListenerImpl {
229 type Item = TcpStreamImpl;
230
231 async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
232 let (stream, addr) = self.0.accept().await?;
233 Ok((TcpStreamImpl::new(stream), addr))
234 }
235}
236
237pub struct UdpListenerProvider;
239
240impl ListenerProvider for UdpListenerProvider {
241 type Listener = UdpListenerImpl;
242
243 async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
244 Ok(UdpListenerImpl(UdpListener::bind(addr).await?))
245 }
246}
247
248pub struct UdpListenerImpl(UdpListener);
250
251impl StreamAccept for UdpListenerImpl {
252 type Item = UdpStreamImpl;
253
254 async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
255 let (stream, addr) = self.0.accept().await?;
256 Ok((UdpStreamImpl::new(stream), addr))
257 }
258}