tarpc_json_transport/
lib.rs

1// Copyright 2019 Google LLC
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT.
6
//! A [`Transport`] that exchanges JSON-serialized messages as length-delimited frames,
//! with helpers to [`connect`] and [`listen`] over TCP.
8
9#![deny(missing_docs)]
10
11use futures::{compat::*, prelude::*, ready};
12use pin_utils::unsafe_pinned;
13use serde::{Deserialize, Serialize};
14use std::{
15    error::Error,
16    io,
17    marker::PhantomData,
18    net::SocketAddr,
19    pin::Pin,
20    task::{Context, Poll},
21};
22use tokio::codec::{length_delimited::LengthDelimitedCodec, Framed};
23use tokio_io::{AsyncRead, AsyncWrite};
24use tokio_serde_json::*;
25use tokio_tcp::{TcpListener, TcpStream};
26
27/// A transport that serializes to, and deserializes from, a [`TcpStream`].
28pub struct Transport<S: AsyncWrite, Item, SinkItem> {
29    inner: Compat01As03Sink<
30        ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
31        SinkItem,
32    >,
33}
34
35impl<S: AsyncWrite, Item, SinkItem> Transport<S, Item, SinkItem> {
36    unsafe_pinned!(
37        inner:
38            Compat01As03Sink<
39                ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
40                SinkItem,
41            >
42    );
43}
44
45impl<S, Item, SinkItem> Stream for Transport<S, Item, SinkItem>
46where
47    S: AsyncWrite + AsyncRead,
48    Item: for<'a> Deserialize<'a>,
49{
50    type Item = io::Result<Item>;
51
52    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
53        match self.inner().poll_next(cx) {
54            Poll::Pending => Poll::Pending,
55            Poll::Ready(None) => Poll::Ready(None),
56            Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))),
57            Poll::Ready(Some(Err(e))) => {
58                Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
59            }
60        }
61    }
62}
63
64impl<S, Item, SinkItem> Sink<SinkItem> for Transport<S, Item, SinkItem>
65where
66    S: AsyncWrite,
67    SinkItem: Serialize,
68{
69    type Error = io::Error;
70
71    fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
72        self.inner()
73            .start_send(item)
74            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
75    }
76
77    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
78        convert(self.inner().poll_ready(cx))
79    }
80
81    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
82        convert(self.inner().poll_flush(cx))
83    }
84
85    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
86        convert(self.inner().poll_close(cx))
87    }
88}
89
90fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
91    poll: Poll<Result<(), E>>,
92) -> Poll<io::Result<()>> {
93    match poll {
94        Poll::Pending => Poll::Pending,
95        Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
96        Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
97    }
98}
99
100impl<Item, SinkItem> Transport<TcpStream, Item, SinkItem> {
101    /// Returns the peer address of the underlying TcpStream.
102    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
103        self.inner
104            .get_ref()
105            .get_ref()
106            .get_ref()
107            .get_ref()
108            .peer_addr()
109    }
110
111    /// Returns the local address of the underlying TcpStream.
112    pub fn local_addr(&self) -> io::Result<SocketAddr> {
113        self.inner
114            .get_ref()
115            .get_ref()
116            .get_ref()
117            .get_ref()
118            .local_addr()
119    }
120}
121
122/// Returns a new JSON transport that reads from and writes to `io`.
123pub fn new<Item, SinkItem>(io: TcpStream) -> Transport<TcpStream, Item, SinkItem>
124where
125    Item: for<'de> Deserialize<'de>,
126    SinkItem: Serialize,
127{
128    Transport::from(io)
129}
130
131impl<S: AsyncWrite + AsyncRead, Item: serde::de::DeserializeOwned, SinkItem: Serialize> From<S>
132    for Transport<S, Item, SinkItem>
133{
134    fn from(inner: S) -> Self {
135        Transport {
136            inner: Compat01As03Sink::new(ReadJson::new(WriteJson::new(Framed::new(
137                inner,
138                LengthDelimitedCodec::new(),
139            )))),
140        }
141    }
142}
143
144/// Connects to `addr`, wrapping the connection in a JSON transport.
145pub async fn connect<Item, SinkItem>(
146    addr: &SocketAddr,
147) -> io::Result<Transport<TcpStream, Item, SinkItem>>
148where
149    Item: for<'de> Deserialize<'de>,
150    SinkItem: Serialize,
151{
152    Ok(new(TcpStream::connect(addr).compat().await?))
153}
154
155/// Listens on `addr`, wrapping accepted connections in JSON transports.
156pub fn listen<Item, SinkItem>(addr: &SocketAddr) -> io::Result<Incoming<Item, SinkItem>>
157where
158    Item: for<'de> Deserialize<'de>,
159    SinkItem: Serialize,
160{
161    let listener = TcpListener::bind(addr)?;
162    let local_addr = listener.local_addr()?;
163    let incoming = listener.incoming().compat();
164    Ok(Incoming {
165        incoming,
166        local_addr,
167        ghost: PhantomData,
168    })
169}
170
171/// A [`TcpListener`] that wraps connections in JSON transports.
172#[derive(Debug)]
173pub struct Incoming<Item, SinkItem> {
174    incoming: Compat01As03<tokio_tcp::Incoming>,
175    local_addr: SocketAddr,
176    ghost: PhantomData<(Item, SinkItem)>,
177}
178
179impl<Item, SinkItem> Incoming<Item, SinkItem> {
180    unsafe_pinned!(incoming: Compat01As03<tokio_tcp::Incoming>);
181
182    /// Returns the address being listened on.
183    pub fn local_addr(&self) -> SocketAddr {
184        self.local_addr
185    }
186}
187
188impl<Item, SinkItem> Stream for Incoming<Item, SinkItem>
189where
190    Item: for<'a> Deserialize<'a>,
191    SinkItem: Serialize,
192{
193    type Item = io::Result<Transport<TcpStream, Item, SinkItem>>;
194
195    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
196        let next = ready!(self.incoming().poll_next(cx)?);
197        Poll::Ready(next.map(|conn| Ok(new(conn))))
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::Transport;
    use assert_matches::assert_matches;
    use futures::{Sink, Stream};
    use futures_test::task::noop_waker_ref;
    use pin_utils::pin_mut;
    use std::{
        io::Cursor,
        task::{Context, Poll},
    };

    /// A context whose waker does nothing; these tests drive every poll by hand.
    fn ctx() -> Context<'static> {
        Context::from_waker(noop_waker_ref())
    }

    #[test]
    fn test_stream() {
        // 4-byte big-endian length prefix (0x18 = 24) followed by a 24-byte JSON string.
        let reader = *b"\x00\x00\x00\x18\"Test one, check check.\"";
        let reader: Box<[u8]> = Box::new(reader);
        let transport = Transport::<_, String, String>::from(Cursor::new(reader));
        pin_mut!(transport);

        assert_matches!(
            transport.as_mut().poll_next(&mut ctx()),
            Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
        // The input holds exactly one frame, so the stream must terminate afterwards.
        assert_matches!(transport.poll_next(&mut ctx()), Poll::Ready(None));
    }

    #[test]
    fn test_sink() {
        let writer: &mut [u8] = &mut [0; 28];
        let transport = Transport::<_, String, String>::from(Cursor::new(&mut *writer));
        pin_mut!(transport);

        assert_matches!(
            transport.as_mut().poll_ready(&mut ctx()),
            Poll::Ready(Ok(()))
        );
        assert_matches!(
            transport
                .as_mut()
                .start_send("Test one, check check.".into()),
            Ok(())
        );
        assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(())));
        assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\"");
    }
}
248}