tfserver/structures/
transport.rs1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use async_tungstenite::{ByteReader, ByteWriter};
5use futures_util::{StreamExt};
6use pin_project::pin_project;
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9#[cfg(not(target_arch = "wasm32"))]
10use tokio::net::TcpStream;
11#[cfg(not(target_arch = "wasm32"))]
12use tokio_rustls::{client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream};
13use tokio_tungstenite::{accept_async, connect_async};
14
15
16pub struct Transport {
17 inner: Box<dyn AsyncReadWrite>,
18}
19
20pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + 'static {
21 #[cfg(not(target_arch = "wasm32"))]
22 fn is_send_sync(&self) where Self: Send + Sync {}
23}
24
25impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static> AsyncReadWrite for T {}
26
27
28
29#[pin_project]
30pub struct WsStreamCompat<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> {
31 #[pin]
32 reader: R,
33 #[pin]
34 writer: W,
35}
36
37impl<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> AsyncRead
38for WsStreamCompat<R, W>
39{
40 fn poll_read(
41 self: Pin<&mut Self>,
42 cx: &mut Context<'_>,
43 buf: &mut ReadBuf<'_>,
44 ) -> Poll<io::Result<()>> {
45 let unfilled = buf.initialize_unfilled();
46 match self.project().reader.poll_read(cx, unfilled) {
47 Poll::Ready(Ok(n)) => {
48 buf.advance(n);
49 Poll::Ready(Ok(()))
50 }
51 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
52 Poll::Pending => Poll::Pending,
53 }
54 }
55}
56
57impl<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> AsyncWrite
58for WsStreamCompat<R, W>
59{
60 fn poll_write(
61 self: Pin<&mut Self>,
62 cx: &mut Context<'_>,
63 buf: &[u8],
64 ) -> Poll<io::Result<usize>> {
65 self.project().writer.poll_write(cx, buf)
66 }
67
68 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
69 self.project().writer.poll_flush(cx)
70 }
71
72 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
73 self.project().writer.poll_close(cx)
74 }
75}
76
77impl Transport {
78 #[cfg(not(target_arch = "wasm32"))]
79 pub fn plain(stream: TcpStream) -> Self {
80 Self { inner: Box::new(stream) }
81 }
82
83 #[cfg(not(target_arch = "wasm32"))]
84 pub fn tls_server(stream: ServerTlsStream<TcpStream>) -> Self {
85 Self { inner: Box::new(stream) }
86 }
87
88 #[cfg(not(target_arch = "wasm32"))]
89 pub fn tls_client(stream: ClientTlsStream<TcpStream>) -> Self {
90 Self { inner: Box::new(stream) }
91 }
92
93 #[cfg(target_arch = "wasm32")]
96 pub async fn connect(url: &str) -> io::Result<Self> {
97 let (_meta, ws_stream) = WsMeta::connect(url, None)
98 .await
99 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
100
101 let (reader, writer) = ws_stream.into_io().split();
102 Ok(Self {
103 inner: Box::new(WsStreamCompat { reader, writer }),
104 })
105 }
106
107 #[cfg(not(target_arch = "wasm32"))]
108 pub async fn connect(url: &str) -> io::Result<Self> {
109 let (ws_stream, _response) = connect_async(url)
110 .await
111 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()))?;
112
113 let (write, read) = ws_stream.split();
114 let reader = ByteReader::new(read);
115 let writer = ByteWriter::new(write);
116
117 Ok(Self {
118 inner: Box::new(WsStreamCompat { reader, writer }),
119 })
120 }
121
122 #[cfg(not(target_arch = "wasm32"))]
123 pub async fn accept_websocket(stream: Transport) -> io::Result<Self> {
124 let ws_stream = accept_async(stream)
125 .await
126 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
127
128 let (write, read) = ws_stream.split();
129 let reader = ByteReader::new(read);
130 let writer = ByteWriter::new(write);
131
132 Ok(Self {
133 inner: Box::new(WsStreamCompat { reader, writer }),
134 })
135 }
136
137 pub fn inner(&mut self) -> &mut dyn AsyncReadWrite {
138 &mut *self.inner
139 }
140}
141
142impl AsyncRead for Transport {
143 fn poll_read(
144 mut self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 buf: &mut ReadBuf<'_>,
147 ) -> Poll<io::Result<()>> {
148 Pin::new(&mut *self.inner).poll_read(cx, buf)
149 }
150}
151
152impl AsyncWrite for Transport {
153 fn poll_write(
154 mut self: Pin<&mut Self>,
155 cx: &mut Context<'_>,
156 buf: &[u8],
157 ) -> Poll<io::Result<usize>> {
158 Pin::new(&mut *self.inner).poll_write(cx, buf)
159 }
160
161 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
162 Pin::new(&mut *self.inner).poll_flush(cx)
163 }
164
165 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166 Pin::new(&mut *self.inner).poll_shutdown(cx)
167 }
168}
169unsafe impl Send for Transport {
170
171}
172unsafe impl Sync for Transport {}