trz_gateway_common/
to_async_io.rs

1use std::future::ready;
2use std::io::ErrorKind;
3
4use bytes::Bytes;
5use futures::Sink;
6use futures::SinkExt as _;
7use futures::Stream;
8use futures::StreamExt as _;
9use nameth::NamedType as _;
10use nameth::nameth;
11use tokio::sync::oneshot;
12use tokio_util::io::CopyToBytes;
13use tokio_util::io::SinkWriter;
14use tokio_util::io::StreamReader;
15
16/// Helper to convert
17/// - an object implementing [Stream] + [Sink]
18/// - into an object implementing [tokio::io::AsyncRead] + [tokio::io::AsyncWrite]
19pub trait WebSocketIo {
20    type Message;
21    type Error: std::error::Error + Send + Sync + 'static;
22
23    fn into_data(message: Self::Message) -> Bytes;
24    fn into_messsge(bytes: Bytes) -> Self::Message;
25
26    fn to_async_io(
27        web_socket: impl Stream<Item = Result<Self::Message, Self::Error>>
28        + Sink<Self::Message, Error = Self::Error>,
29    ) -> (
30        impl tokio::io::AsyncRead + tokio::io::AsyncWrite,
31        impl Future<Output = std::io::Result<()>>,
32    )
33    where
34        Self: Sized,
35    {
36        to_async_io_impl::<Self>(web_socket)
37    }
38}
39
/// Wraps an error raised while reading from the web socket's incoming stream.
///
/// Displayed as `[<type name>] <inner error>`; the type name comes from
/// `nameth`'s `NamedType::type_name()`.
#[nameth]
#[derive(thiserror::Error, Debug)]
#[error("[{n}] {0}", n = Self::type_name())]
struct ReadError<E>(E);
44
/// Wraps an error raised while sending to the web socket's outgoing sink.
///
/// Displayed as `[<type name>] <inner error>`; the type name comes from
/// `nameth`'s `NamedType::type_name()`.
#[nameth]
#[derive(thiserror::Error, Debug)]
#[error("[{n}] {0}", n = Self::type_name())]
struct WriteError<E>(E);
49
50fn to_async_io_impl<IO: WebSocketIo>(
51    web_socket: impl Stream<Item = Result<IO::Message, IO::Error>>
52    + Sink<IO::Message, Error = IO::Error>,
53) -> (
54    impl tokio::io::AsyncRead + tokio::io::AsyncWrite,
55    impl Future<Output = std::io::Result<()>>,
56) {
57    let (error_tx, error_rx) = oneshot::channel();
58    let mut error_tx = Some(error_tx);
59    let (sink, stream) = web_socket.split();
60
61    let reader = {
62        StreamReader::new(stream.map(move |message| {
63            let message = message.map(IO::into_data).map_err(|error: IO::Error| {
64                let error = std::io::Error::new(ErrorKind::ConnectionAborted, ReadError(error));
65                let error_tx = error_tx.take();
66                error_tx.map(|error_tx| error_tx.send(error));
67                return ErrorKind::ConnectionAborted;
68            });
69            return message;
70        }))
71    };
72
73    let writer = {
74        let sink = CopyToBytes::new(sink.with(|data| ready(Ok(IO::into_messsge(data)))))
75            .sink_map_err(|error: IO::Error| {
76                std::io::Error::new(ErrorKind::ConnectionAborted, WriteError(error))
77            });
78        SinkWriter::new(sink)
79    };
80
81    let eos = Box::pin(async {
82        match error_rx.await {
83            // The stream raised an error.
84            Ok(error) => Err(error),
85
86            // The stream was dropped, finished without raising an error.
87            Err(oneshot::error::RecvError { .. }) => Ok(()),
88        }
89    });
90    (tokio::io::join(reader, writer), eos)
91}