1#![deny(missing_docs)]
10
11use futures::{compat::*, prelude::*, ready};
12use pin_utils::unsafe_pinned;
13use serde::{Deserialize, Serialize};
14use std::{
15 error::Error,
16 io,
17 marker::PhantomData,
18 net::SocketAddr,
19 pin::Pin,
20 task::{Context, Poll},
21};
22use tokio::codec::{length_delimited::LengthDelimitedCodec, Framed};
23use tokio_io::{AsyncRead, AsyncWrite};
24use tokio_serde_json::*;
25use tokio_tcp::{TcpListener, TcpStream};
26
27pub struct Transport<S: AsyncWrite, Item, SinkItem> {
29 inner: Compat01As03Sink<
30 ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
31 SinkItem,
32 >,
33}
34
impl<S: AsyncWrite, Item, SinkItem> Transport<S, Item, SinkItem> {
    // Generates an `inner()` accessor that takes `self: Pin<&mut Self>` and projects the pin onto
    // the `inner` field, so the `Stream` and `Sink` impls can poll the codec stack in place.
    //
    // SAFETY (review note): `unsafe_pinned!` is only sound if `Transport` never moves `inner` out
    // once pinned, has no `Drop` impl that does so, and has no manual `Unpin` impl that ignores
    // `inner`. None of these appear in the visible source — re-check if such impls are added.
    unsafe_pinned!(
        inner:
            Compat01As03Sink<
                ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
                SinkItem,
            >
    );
}
44
45impl<S, Item, SinkItem> Stream for Transport<S, Item, SinkItem>
46where
47 S: AsyncWrite + AsyncRead,
48 Item: for<'a> Deserialize<'a>,
49{
50 type Item = io::Result<Item>;
51
52 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
53 match self.inner().poll_next(cx) {
54 Poll::Pending => Poll::Pending,
55 Poll::Ready(None) => Poll::Ready(None),
56 Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))),
57 Poll::Ready(Some(Err(e))) => {
58 Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
59 }
60 }
61 }
62}
63
64impl<S, Item, SinkItem> Sink<SinkItem> for Transport<S, Item, SinkItem>
65where
66 S: AsyncWrite,
67 SinkItem: Serialize,
68{
69 type Error = io::Error;
70
71 fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
72 self.inner()
73 .start_send(item)
74 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
75 }
76
77 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
78 convert(self.inner().poll_ready(cx))
79 }
80
81 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
82 convert(self.inner().poll_flush(cx))
83 }
84
85 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
86 convert(self.inner().poll_close(cx))
87 }
88}
89
90fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
91 poll: Poll<Result<(), E>>,
92) -> Poll<io::Result<()>> {
93 match poll {
94 Poll::Pending => Poll::Pending,
95 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
96 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
97 }
98}
99
100impl<Item, SinkItem> Transport<TcpStream, Item, SinkItem> {
101 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
103 self.inner
104 .get_ref()
105 .get_ref()
106 .get_ref()
107 .get_ref()
108 .peer_addr()
109 }
110
111 pub fn local_addr(&self) -> io::Result<SocketAddr> {
113 self.inner
114 .get_ref()
115 .get_ref()
116 .get_ref()
117 .get_ref()
118 .local_addr()
119 }
120}
121
122pub fn new<Item, SinkItem>(io: TcpStream) -> Transport<TcpStream, Item, SinkItem>
124where
125 Item: for<'de> Deserialize<'de>,
126 SinkItem: Serialize,
127{
128 Transport::from(io)
129}
130
131impl<S: AsyncWrite + AsyncRead, Item: serde::de::DeserializeOwned, SinkItem: Serialize> From<S>
132 for Transport<S, Item, SinkItem>
133{
134 fn from(inner: S) -> Self {
135 Transport {
136 inner: Compat01As03Sink::new(ReadJson::new(WriteJson::new(Framed::new(
137 inner,
138 LengthDelimitedCodec::new(),
139 )))),
140 }
141 }
142}
143
144pub async fn connect<Item, SinkItem>(
146 addr: &SocketAddr,
147) -> io::Result<Transport<TcpStream, Item, SinkItem>>
148where
149 Item: for<'de> Deserialize<'de>,
150 SinkItem: Serialize,
151{
152 Ok(new(TcpStream::connect(addr).compat().await?))
153}
154
155pub fn listen<Item, SinkItem>(addr: &SocketAddr) -> io::Result<Incoming<Item, SinkItem>>
157where
158 Item: for<'de> Deserialize<'de>,
159 SinkItem: Serialize,
160{
161 let listener = TcpListener::bind(addr)?;
162 let local_addr = listener.local_addr()?;
163 let incoming = listener.incoming().compat();
164 Ok(Incoming {
165 incoming,
166 local_addr,
167 ghost: PhantomData,
168 })
169}
170
171#[derive(Debug)]
173pub struct Incoming<Item, SinkItem> {
174 incoming: Compat01As03<tokio_tcp::Incoming>,
175 local_addr: SocketAddr,
176 ghost: PhantomData<(Item, SinkItem)>,
177}
178
179impl<Item, SinkItem> Incoming<Item, SinkItem> {
180 unsafe_pinned!(incoming: Compat01As03<tokio_tcp::Incoming>);
181
182 pub fn local_addr(&self) -> SocketAddr {
184 self.local_addr
185 }
186}
187
188impl<Item, SinkItem> Stream for Incoming<Item, SinkItem>
189where
190 Item: for<'a> Deserialize<'a>,
191 SinkItem: Serialize,
192{
193 type Item = io::Result<Transport<TcpStream, Item, SinkItem>>;
194
195 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
196 let next = ready!(self.incoming().poll_next(cx)?);
197 Poll::Ready(next.map(|conn| Ok(new(conn))))
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::Transport;
    use assert_matches::assert_matches;
    use futures::{Sink, Stream};
    use futures_test::task::noop_waker_ref;
    use pin_utils::pin_mut;
    use std::{
        io::Cursor,
        task::{Context, Poll},
    };

    /// Builds a task context with a waker that does nothing.
    ///
    /// Every poll in these tests is expected to complete immediately, so wake-ups are never
    /// needed. `noop_waker_ref()` already returns a `&'static Waker`, so no extra borrow is
    /// taken.
    fn ctx() -> Context<'static> {
        Context::from_waker(noop_waker_ref())
    }

    /// Decodes one frame: a 4-byte length prefix (`0x18` = 24) followed by the 24-byte JSON
    /// string literal `"Test one, check check."`.
    #[test]
    fn test_stream() {
        let reader = *b"\x00\x00\x00\x18\"Test one, check check.\"";
        let reader: Box<[u8]> = Box::new(reader);
        let transport = Transport::<_, String, String>::from(Cursor::new(reader));
        pin_mut!(transport);

        assert_matches!(
            transport.poll_next(&mut ctx()),
            Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
    }

    /// Encodes one item into a buffer sized exactly for one frame (4-byte prefix plus 24-byte
    /// payload = 28 bytes) and checks the bytes written after a flush.
    #[test]
    fn test_sink() {
        let writer: &mut [u8] = &mut [0; 28];
        let transport = Transport::<_, String, String>::from(Cursor::new(&mut *writer));
        pin_mut!(transport);

        assert_matches!(
            transport.as_mut().poll_ready(&mut ctx()),
            Poll::Ready(Ok(()))
        );
        assert_matches!(
            transport
                .as_mut()
                .start_send("Test one, check check.".into()),
            Ok(())
        );
        assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(())));
        assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\"");
    }
}