Skip to main content

tfserver/structures/
transport.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use async_tungstenite::{ByteReader, ByteWriter};
5use futures_util::{StreamExt};
6use pin_project::pin_project;
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9#[cfg(not(target_arch = "wasm32"))]
10use tokio::net::TcpStream;
11#[cfg(not(target_arch = "wasm32"))]
12use tokio_rustls::{client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream};
13use tokio_tungstenite::{accept_async, connect_async};
14
15
16pub struct Transport {
17    inner: Box<dyn AsyncReadWrite>,
18}
19
20pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + 'static {
21    #[cfg(not(target_arch = "wasm32"))]
22    fn is_send_sync(&self) where Self: Send + Sync {}
23}
24
25impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static> AsyncReadWrite for T {}
26
27
28
29#[pin_project]
30pub struct WsStreamCompat<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> {
31    #[pin]
32    reader: R,
33    #[pin]
34    writer: W,
35}
36
37impl<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> AsyncRead
38for WsStreamCompat<R, W>
39{
40    fn poll_read(
41        self: Pin<&mut Self>,
42        cx: &mut Context<'_>,
43        buf: &mut ReadBuf<'_>,
44    ) -> Poll<io::Result<()>> {
45        let unfilled = buf.initialize_unfilled();
46        match self.project().reader.poll_read(cx, unfilled) {
47            Poll::Ready(Ok(n)) => {
48                buf.advance(n);
49                Poll::Ready(Ok(()))
50            }
51            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
52            Poll::Pending => Poll::Pending,
53        }
54    }
55}
56
57impl<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> AsyncWrite
58for WsStreamCompat<R, W>
59{
60    fn poll_write(
61        self: Pin<&mut Self>,
62        cx: &mut Context<'_>,
63        buf: &[u8],
64    ) -> Poll<io::Result<usize>> {
65        self.project().writer.poll_write(cx, buf)
66    }
67
68    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
69        self.project().writer.poll_flush(cx)
70    }
71
72    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
73        self.project().writer.poll_close(cx)
74    }
75}
76
77impl Transport {
78    #[cfg(not(target_arch = "wasm32"))]
79    pub fn plain(stream: TcpStream) -> Self {
80        Self { inner: Box::new(stream) }
81    }
82    
83
84    #[cfg(not(target_arch = "wasm32"))]
85    pub fn tls_server(stream: ServerTlsStream<TcpStream>) -> Self {
86        Self { inner: Box::new(stream) }
87    }
88
89    #[cfg(not(target_arch = "wasm32"))]
90    pub fn tls_client(stream: ClientTlsStream<TcpStream>) -> Self {
91        Self { inner: Box::new(stream) }
92    }
93
94    /// On WASM: connect via WebSocket, returns a Transport backed by ws_stream_wasm.
95    /// On native: not available — use plain/tls_client/tls_server + a WS proxy if needed.
96    #[cfg(target_arch = "wasm32")]
97    pub async fn connect(url: &str) -> io::Result<Self> {
98        let (_meta, ws_stream) = WsMeta::connect(url, None)
99            .await
100            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
101
102        let (reader, writer) = ws_stream.into_io().split();
103        Ok(Self {
104            inner: Box::new(WsStreamCompat { reader, writer }),
105        })
106    }
107
108    #[cfg(not(target_arch = "wasm32"))]
109    pub async fn connect(url: &str) -> io::Result<Self> {
110        let (ws_stream, _response) = connect_async(url)
111            .await
112            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()))?;
113
114        let (write, read) = ws_stream.split();
115        let reader = ByteReader::new(read);
116        let writer = ByteWriter::new(write);
117
118        Ok(Self {
119            inner: Box::new(WsStreamCompat { reader, writer }),
120        })
121    }
122
123    #[cfg(not(target_arch = "wasm32"))]
124    pub async fn accept_websocket(stream: Transport) -> io::Result<Self> {
125        let ws_stream = accept_async(stream)
126            .await
127            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
128
129        let (write, read) = ws_stream.split();
130        let reader = ByteReader::new(read);
131        let writer = ByteWriter::new(write);
132
133        Ok(Self {
134            inner: Box::new(WsStreamCompat { reader, writer }),
135        })
136    }
137
138    pub fn inner(&mut self) -> &mut dyn AsyncReadWrite {
139        &mut *self.inner
140    }
141}
142
143impl AsyncRead for Transport {
144    fn poll_read(
145        mut self: Pin<&mut Self>,
146        cx: &mut Context<'_>,
147        buf: &mut ReadBuf<'_>,
148    ) -> Poll<io::Result<()>> {
149        Pin::new(&mut *self.inner).poll_read(cx, buf)
150    }
151}
152
153impl AsyncWrite for Transport {
154    fn poll_write(
155        mut self: Pin<&mut Self>,
156        cx: &mut Context<'_>,
157        buf: &[u8],
158    ) -> Poll<io::Result<usize>> {
159        Pin::new(&mut *self.inner).poll_write(cx, buf)
160    }
161
162    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
163        Pin::new(&mut *self.inner).poll_flush(cx)
164    }
165
166    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
167        Pin::new(&mut *self.inner).poll_shutdown(cx)
168    }
169}
170unsafe impl Send for Transport {
171
172}
173unsafe impl Sync for Transport {}