trz_gateway_common/
to_async_io.rs1use std::future::ready;
2use std::io::ErrorKind;
3
4use bytes::Bytes;
5use futures::Sink;
6use futures::SinkExt as _;
7use futures::Stream;
8use futures::StreamExt as _;
9use nameth::NamedType as _;
10use nameth::nameth;
11use tokio::sync::oneshot;
12use tokio_util::io::CopyToBytes;
13use tokio_util::io::SinkWriter;
14use tokio_util::io::StreamReader;
15
/// Adapter from a message-oriented web socket to a byte-oriented
/// [`tokio::io::AsyncRead`] + [`tokio::io::AsyncWrite`] object.
///
/// Implementors only describe how to convert between their message type and
/// raw [`Bytes`]; the provided [`to_async_io`](WebSocketIo::to_async_io)
/// method performs the actual adaptation.
pub trait WebSocketIo {
    /// Message type carried by the web socket's stream and sink halves.
    type Message;

    /// Error type yielded by the web socket's stream and sink halves.
    ///
    /// Errors of this type end up wrapped inside a [`std::io::Error`], which
    /// is why `Send + Sync + 'static` is required.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Extracts the payload bytes of a message received from the web socket.
    ///
    /// Every message read from the stream half is passed through this
    /// function, and the returned bytes are what the `AsyncRead` side yields.
    fn into_data(message: Self::Message) -> Bytes;

    /// Builds the message sent on the web socket from bytes written to the
    /// `AsyncWrite` side.
    ///
    /// NOTE: the method name is misspelled (`messsge`). It is kept as-is
    /// because renaming a required trait method would break every implementor.
    fn into_messsge(bytes: Bytes) -> Self::Message;

    /// Converts `web_socket` into a duplex byte stream.
    ///
    /// Returns a pair `(io, eos)`:
    /// - `io` reads the payloads of incoming messages (converted with
    ///   [`into_data`](WebSocketIo::into_data)). It sends written bytes as
    ///   messages (built with [`into_messsge`](WebSocketIo::into_messsge)).
    /// - `eos` resolves to `Err` carrying the first error reported by the
    ///   stream half. Otherwise it resolves to `Ok(())` once `io` is dropped.
    ///   It does not resolve merely because the stream half ended while `io`
    ///   is still alive.
    ///
    /// # Errors
    ///
    /// - Reads from `io` fail with [`std::io::ErrorKind::ConnectionAborted`]
    ///   when the stream half yields an error. The detailed error is delivered
    ///   only through `eos`.
    /// - Writes fail with an error of the same kind that wraps the sink's
    ///   error.
    fn to_async_io(
        web_socket: impl Stream<Item = Result<Self::Message, Self::Error>>
            + Sink<Self::Message, Error = Self::Error>,
    ) -> (
        impl tokio::io::AsyncRead + tokio::io::AsyncWrite,
        impl Future<Output = std::io::Result<()>>,
    )
    where
        Self: Sized,
    {
        to_async_io_impl::<Self>(web_socket)
    }
}
39
/// Wraps an error yielded by the web socket's stream half (the read side).
///
/// `to_async_io_impl` puts this inside a `std::io::Error` of kind
/// `ConnectionAborted`.
///
/// `Display` shows the inner error prefixed with `[<type name>]`, where the
/// name comes from `nameth` (`Self::type_name()`). The inner error is not
/// marked `#[source]`, so it appears only in the message and not in `source()`.
#[nameth]
#[derive(thiserror::Error, Debug)]
#[error("[{n}] {0}", n = Self::type_name())]
struct ReadError<E>(E);
44
/// Wraps an error yielded by the web socket's sink half (the write side).
///
/// `to_async_io_impl` puts this inside a `std::io::Error` of kind
/// `ConnectionAborted`, which is returned directly to the writer.
///
/// `Display` shows the inner error prefixed with `[<type name>]`, where the
/// name comes from `nameth` (`Self::type_name()`). The inner error is not
/// marked `#[source]`, so it appears only in the message and not in `source()`.
#[nameth]
#[derive(thiserror::Error, Debug)]
#[error("[{n}] {0}", n = Self::type_name())]
struct WriteError<E>(E);
49
50fn to_async_io_impl<IO: WebSocketIo>(
51 web_socket: impl Stream<Item = Result<IO::Message, IO::Error>>
52 + Sink<IO::Message, Error = IO::Error>,
53) -> (
54 impl tokio::io::AsyncRead + tokio::io::AsyncWrite,
55 impl Future<Output = std::io::Result<()>>,
56) {
57 let (error_tx, error_rx) = oneshot::channel();
58 let mut error_tx = Some(error_tx);
59 let (sink, stream) = web_socket.split();
60
61 let reader = {
62 StreamReader::new(stream.map(move |message| {
63 let message = message.map(IO::into_data).map_err(|error: IO::Error| {
64 let error = std::io::Error::new(ErrorKind::ConnectionAborted, ReadError(error));
65 let error_tx = error_tx.take();
66 error_tx.map(|error_tx| error_tx.send(error));
67 return ErrorKind::ConnectionAborted;
68 });
69 return message;
70 }))
71 };
72
73 let writer = {
74 let sink = CopyToBytes::new(sink.with(|data| ready(Ok(IO::into_messsge(data)))))
75 .sink_map_err(|error: IO::Error| {
76 std::io::Error::new(ErrorKind::ConnectionAborted, WriteError(error))
77 });
78 SinkWriter::new(sink)
79 };
80
81 let eos = Box::pin(async {
82 match error_rx.await {
83 Ok(error) => Err(error),
85
86 Err(oneshot::error::RecvError { .. }) => Ok(()),
88 }
89 });
90 (tokio::io::join(reader, writer), eos)
91}