transmog_async/
lib.rs

1#![doc = include_str!("./.crate-docs.md")]
2#![forbid(unsafe_code)]
3#![warn(
4    clippy::cargo,
5    missing_docs,
6    // clippy::missing_docs_in_private_items,
7    clippy::pedantic,
8    future_incompatible,
9    rust_2018_idioms,
10)]
11#![allow(
12    clippy::missing_errors_doc, // TODO clippy::missing_errors_doc
13    clippy::option_if_let_else,
14    clippy::module_name_repetitions,
15)]
16
17mod reader;
18mod writer;
19
20use std::{
21    fmt, io,
22    marker::PhantomData,
23    ops::{Deref, DerefMut},
24    pin::Pin,
25    task::{Context, Poll},
26};
27
28use futures_core::Stream;
29use futures_sink::Sink;
30use tokio::io::{AsyncRead, ReadBuf};
31pub use transmog;
32use transmog::Format;
33
34pub use self::{
35    reader::TransmogReader,
36    writer::{AsyncDestination, SyncDestination, TransmogWriter, TransmogWriterFor},
37};
38
/// Builder helper to specify types without the need of turbofishing.
///
/// `TReads` and `TWrites` are tracked purely at the type level; both start
/// out as `()` (see [`Builder::new`]) and are replaced by the
/// `sends`/`receives`/`sends_and_receives` methods.
pub struct Builder<TReads, TWrites, TStream, TFormat> {
    // The underlying transport that the finished stream will wrap.
    stream: TStream,
    // The format handed to the finished stream.
    format: TFormat,
    // Zero-sized marker recording the chosen read and write payload types.
    datatypes: PhantomData<(TReads, TWrites)>,
}
45
46impl<TStream, TFormat> Builder<(), (), TStream, TFormat> {
47    /// Returns a new stream builder for `stream` and `format`.
48    pub fn new(stream: TStream, format: TFormat) -> Self {
49        Self {
50            stream,
51            format,
52            datatypes: PhantomData,
53        }
54    }
55}
56
57impl<TStream, TFormat> Builder<(), (), TStream, TFormat> {
58    /// Sets `T` as the type for both sending and receiving.
59    pub fn sends_and_receives<T>(self) -> Builder<T, T, TStream, TFormat>
60    where
61        TFormat: Format<'static, T>,
62    {
63        Builder {
64            stream: self.stream,
65            format: self.format,
66            datatypes: PhantomData,
67        }
68    }
69}
70
71impl<TReads, TStream, TFormat> Builder<TReads, (), TStream, TFormat> {
72    /// Sets `T` as the type of data that is written to this stream.
73    pub fn sends<T>(self) -> Builder<TReads, T, TStream, TFormat>
74    where
75        TFormat: Format<'static, T>,
76    {
77        Builder {
78            stream: self.stream,
79            format: self.format,
80            datatypes: PhantomData,
81        }
82    }
83}
84
85impl<TWrites, TStream, TFormat> Builder<(), TWrites, TStream, TFormat> {
86    /// Sets `T` as the type of data that is read from this stream.
87    pub fn receives<T>(self) -> Builder<T, TWrites, TStream, TFormat>
88    where
89        TFormat: Format<'static, T>,
90    {
91        Builder {
92            stream: self.stream,
93            format: self.format,
94            datatypes: PhantomData,
95        }
96    }
97}
98
99impl<TReads, TWrites, TStream, TFormat> Builder<TReads, TWrites, TStream, TFormat>
100where
101    TFormat: Clone,
102{
103    /// Build this stream to include the serialized data's size before each
104    /// serialized value.
105    ///
106    /// This is necessary for compatability with a remote [`TransmogReader`].
107    pub fn for_async(self) -> TransmogStream<TReads, TWrites, TStream, AsyncDestination, TFormat> {
108        TransmogStream::new(self.stream, self.format).for_async()
109    }
110
111    /// Build this stream only send Transmog-encoded values.
112    ///
113    /// This is necessary for compatability with stock Transmog receivers.
114    pub fn for_sync(self) -> TransmogStream<TReads, TWrites, TStream, SyncDestination, TFormat> {
115        TransmogStream::new(self.stream, self.format)
116    }
117}
118
/// A wrapper around an asynchronous stream that receives and sends Transmog-encoded values.
///
/// To use, provide a stream that implements both [`AsyncWrite`](tokio::io::AsyncWrite) and [`AsyncRead`], and then use
/// [`Sink`] to send values and [`Stream`] to receive them.
///
/// Note that a `TransmogStream` must be of the type [`AsyncDestination`] in order to be
/// compatible with a [`TransmogReader`] on the remote end (recall that it requires the
/// serialized size prefixed to the serialized data). The default is [`SyncDestination`], but these
/// can be easily toggled between using [`TransmogStream::for_async`].
#[derive(Debug)]
pub struct TransmogStream<TReads, TWrites, TStream, TDestination, TFormat> {
    // The reader wraps the writer, and the writer wraps the underlying
    // `TStream`, so a single field owns both directions.
    stream: TransmogReader<
        InternalTransmogWriter<TStream, TWrites, TDestination, TFormat>,
        TReads,
        TFormat,
    >,
}
136
// Newtype around `TransmogWriter` used as the inner "stream" of the
// `TransmogReader` stored in `TransmogStream`. It forwards `AsyncRead` to the
// underlying `TStream` and derefs to the wrapped writer.
#[doc(hidden)]
pub struct InternalTransmogWriter<TStream, T, TDestination, TFormat>(
    TransmogWriter<TStream, T, TDestination, TFormat>,
);
141
142impl<TStream: fmt::Debug, T, TDestination, TFormat> fmt::Debug
143    for InternalTransmogWriter<TStream, T, TDestination, TFormat>
144{
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        self.get_ref().fmt(f)
147    }
148}
149
150impl<TReads, TWrites, TStream, TDestination, TFormat>
151    TransmogStream<TReads, TWrites, TStream, TDestination, TFormat>
152{
153    /// Gets a reference to the underlying stream.
154    ///
155    /// It is inadvisable to directly read from or write to the underlying stream.
156    pub fn get_ref(&self) -> &TStream {
157        self.stream.get_ref().0.get_ref()
158    }
159
160    /// Gets a mutable reference to the underlying stream.
161    ///
162    /// It is inadvisable to directly read from or write to the underlying stream.
163    pub fn get_mut(&mut self) -> &mut TStream {
164        self.stream.get_mut().0.get_mut()
165    }
166
167    /// Unwraps this `TransmogStream`, returning the underlying stream.
168    ///
169    /// Note that any leftover serialized data that has not yet been sent, or received data that
170    /// has not yet been deserialized, is lost.
171    pub fn into_inner(self) -> (TStream, TFormat) {
172        self.stream.into_inner().0.into_inner()
173    }
174}
175
176impl<TStream, TFormat> TransmogStream<(), (), TStream, SyncDestination, TFormat> {
177    /// Creates a new instance that sends `format`-encoded payloads over `stream`.
178    pub fn build(stream: TStream, format: TFormat) -> Builder<(), (), TStream, TFormat> {
179        Builder::new(stream, format)
180    }
181}
182
183impl<TReads, TWrites, TStream, TFormat>
184    TransmogStream<TReads, TWrites, TStream, SyncDestination, TFormat>
185where
186    TFormat: Clone,
187{
188    /// Creates a new instance that sends `format`-encoded payloads over `stream`.
189    pub fn new(stream: TStream, format: TFormat) -> Self {
190        TransmogStream {
191            stream: TransmogReader::new(
192                InternalTransmogWriter(TransmogWriter::new(stream, format.clone())),
193                format,
194            ),
195        }
196    }
197
198    /// Creates a new instance that sends `format`-encoded payloads over the
199    /// default stream for `TStream`.
200    pub fn default_for(format: TFormat) -> Self
201    where
202        TStream: Default,
203    {
204        Self::new(TStream::default(), format)
205    }
206}
207
208impl<TReads, TWrites, TStream, TDestination, TFormat>
209    TransmogStream<TReads, TWrites, TStream, TDestination, TFormat>
210where
211    TFormat: Clone,
212{
213    /// Make this stream include the serialized data's size before each serialized value.
214    ///
215    /// This is necessary for compatability with a remote [`TransmogReader`].
216    pub fn for_async(self) -> TransmogStream<TReads, TWrites, TStream, AsyncDestination, TFormat> {
217        let (stream, format) = self.into_inner();
218        TransmogStream {
219            stream: TransmogReader::new(
220                InternalTransmogWriter(TransmogWriter::new(stream, format.clone()).for_async()),
221                format,
222            ),
223        }
224    }
225
226    /// Make this stream only send Transmog-encoded values.
227    ///
228    /// This is necessary for compatability with stock Transmog receivers.
229    pub fn for_sync(self) -> TransmogStream<TReads, TWrites, TStream, SyncDestination, TFormat> {
230        let (stream, format) = self.into_inner();
231        TransmogStream::new(stream, format)
232    }
233}
234
/// A reader of Transmog-encoded data from a [`TcpStream`](tokio::net::TcpStream).
///
/// Wraps the borrowed read half of the TCP stream; see
/// [`TransmogStream::tcp_split`].
pub type TransmogTokioTcpReader<'a, TReads, TFormat> =
    TransmogReader<tokio::net::tcp::ReadHalf<'a>, TReads, TFormat>;
/// A writer of Transmog-encoded data to a [`TcpStream`](tokio::net::TcpStream).
///
/// Wraps the borrowed write half of the TCP stream; see
/// [`TransmogStream::tcp_split`].
pub type TransmogTokioTcpWriter<'a, TWrites, TDestination, TFormat> =
    TransmogWriter<tokio::net::tcp::WriteHalf<'a>, TWrites, TDestination, TFormat>;
241
242impl<TReads, TWrites, TDestination, TFormat>
243    TransmogStream<TReads, TWrites, tokio::net::TcpStream, TDestination, TFormat>
244where
245    TFormat: Clone,
246{
247    /// Split a TCP-based stream into a read half and a write half.
248    ///
249    /// This is more performant than using a lock-based split like the one provided by `tokio-io`
250    /// or `futures-util` since we know that reads and writes to a `TcpStream` can continue
251    /// concurrently.
252    ///
253    /// Any partially sent or received state is preserved.
254    pub fn tcp_split(
255        &mut self,
256    ) -> (
257        TransmogTokioTcpReader<'_, TReads, TFormat>,
258        TransmogTokioTcpWriter<'_, TWrites, TDestination, TFormat>,
259    ) {
260        // First, steal the reader state so it isn't lost
261        let rbuff = self.stream.buffer.split();
262        // Then, fish out the writer
263        let writer = &mut self.stream.get_mut().0;
264        let format = writer.format().clone();
265        // And steal the writer state so it isn't lost
266        let write_buffer = writer.buffer.split_off(0);
267        let write_buffer_written = writer.written;
268        // Now split the stream
269        let (r, w) = writer.get_mut().split();
270        // Then put the reader back together
271        let mut reader = TransmogReader::new(r, format.clone());
272        reader.buffer = rbuff;
273        // And then the writer
274        let mut writer: TransmogWriter<_, _, TDestination, TFormat> =
275            TransmogWriter::new(w, format).make_for();
276        writer.buffer = write_buffer;
277        writer.written = write_buffer_written;
278        // All good!
279        (reader, writer)
280    }
281}
282
283impl<TStream, T, TDestination, TFormat> AsyncRead
284    for InternalTransmogWriter<TStream, T, TDestination, TFormat>
285where
286    TStream: AsyncRead + Unpin,
287{
288    fn poll_read(
289        self: Pin<&mut Self>,
290        cx: &mut Context<'_>,
291        buf: &mut ReadBuf<'_>,
292    ) -> Poll<Result<(), io::Error>> {
293        Pin::new(self.get_mut().get_mut()).poll_read(cx, buf)
294    }
295}
296
297impl<TStream, T, TDestination, TFormat> Deref
298    for InternalTransmogWriter<TStream, T, TDestination, TFormat>
299{
300    type Target = TransmogWriter<TStream, T, TDestination, TFormat>;
301    fn deref(&self) -> &Self::Target {
302        &self.0
303    }
304}
305impl<TStream, T, TDestination, TFormat> DerefMut
306    for InternalTransmogWriter<TStream, T, TDestination, TFormat>
307{
308    fn deref_mut(&mut self) -> &mut Self::Target {
309        &mut self.0
310    }
311}
312
313impl<TReads, TWrites, TStream, TDestination, TFormat> Stream
314    for TransmogStream<TReads, TWrites, TStream, TDestination, TFormat>
315where
316    TStream: Unpin,
317    TransmogReader<
318        InternalTransmogWriter<TStream, TWrites, TDestination, TFormat>,
319        TReads,
320        TFormat,
321    >: Stream<Item = Result<TReads, TFormat::Error>>,
322    TFormat: Format<'static, TWrites>,
323{
324    type Item = Result<TReads, TFormat::Error>;
325    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
326        Pin::new(&mut self.stream).poll_next(cx)
327    }
328}
329
330impl<TReads, TWrites, TStream, TDestination, TFormat> Sink<TWrites>
331    for TransmogStream<TReads, TWrites, TStream, TDestination, TFormat>
332where
333    TStream: Unpin,
334    TransmogWriter<TStream, TWrites, TDestination, TFormat>: Sink<TWrites, Error = TFormat::Error>,
335    TFormat: Format<'static, TWrites>,
336{
337    type Error = TFormat::Error;
338
339    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
340        Pin::new(&mut **self.stream.get_mut()).poll_ready(cx)
341    }
342
343    fn start_send(mut self: Pin<&mut Self>, item: TWrites) -> Result<(), Self::Error> {
344        Pin::new(&mut **self.stream.get_mut()).start_send(item)
345    }
346
347    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
348        Pin::new(&mut **self.stream.get_mut()).poll_flush(cx)
349    }
350
351    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
352        Pin::new(&mut **self.stream.get_mut()).poll_close(cx)
353    }
354}
355
#[cfg(test)]
mod tests {
    use futures::prelude::*;
    use transmog::OwnedDeserializer;
    use transmog_bincode::Bincode;
    use transmog_pot::Pot;

    use super::*;

    /// Sends each of `values` to a local TCP echo server and checks that the
    /// same value comes back, using `format` on both ends.
    async fn it_works<T, TFormat>(format: TFormat, values: &[T])
    where
        T: std::fmt::Debug + Clone + PartialEq + Send,
        TFormat: Format<'static, T> + OwnedDeserializer<T> + Clone + 'static,
    {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Echo server: forward everything read from the socket back into it.
        let server_format = format.clone();
        tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut server =
                TransmogStream::<T, T, _, _, _>::new(socket, server_format).for_async();
            let (reader, writer) = server.tcp_split();
            reader.forward(writer).await.unwrap();
        });

        let socket = tokio::net::TcpStream::connect(&addr).await.unwrap();
        let mut client = TransmogStream::<T, T, _, _, _>::new(socket, format).for_async();

        for expected in values {
            client.send(expected.clone()).await.unwrap();
            let echoed = client.next().await.unwrap().unwrap();
            assert_eq!(&echoed, expected);
        }

        drop(client);
    }

    #[tokio::test]
    async fn it_works_bincode() {
        // Short payloads first...
        it_works(Bincode::default(), &[44, 42]).await;
        // ...then a long one.
        it_works(Bincode::default(), &[vec![0_u8; 1_000_000]]).await;
    }

    #[tokio::test]
    async fn it_works_pot() {
        // Short payloads first...
        it_works(Pot::default(), &[44, 42]).await;
        // ...then a long one.
        it_works(Pot::default(), &[vec![0_u8; 1_000_000]]).await;
    }

    /// Pushes many values through the echo server in one go, then checks that
    /// they all come back in order.
    #[tokio::test]
    async fn lots() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut server =
                TransmogStream::<usize, usize, _, _, _>::new(socket, Bincode::default())
                    .for_async();
            let (reader, writer) = server.tcp_split();
            reader.forward(writer).await.unwrap();
        });

        let count = 81920;
        let socket = tokio::net::TcpStream::connect(&addr).await.unwrap();
        let mut client = TransmogStream::new(socket, Bincode::default()).for_async();

        futures::stream::iter(0_usize..count)
            .map(Ok)
            .forward(&mut client)
            .await
            .unwrap();

        let mut received = 0;
        while let Some(value) = client.next().await.transpose().unwrap() {
            assert_eq!(received, value);
            received += 1;
        }
        assert_eq!(received, count);
    }
}
439        assert_eq!(at, n);
440    }
441}