tfserver/structures/
transport.rs1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use futures_util::{StreamExt};
5use pin_project::pin_project;
6use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7
8#[cfg(not(target_arch = "wasm32"))]
9use async_tungstenite::{ByteReader, ByteWriter};
10#[cfg(not(target_arch = "wasm32"))]
11use tokio::net::TcpStream;
12#[cfg(not(target_arch = "wasm32"))]
13use tokio_rustls::{client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream};
14#[cfg(not(target_arch = "wasm32"))]
15use tokio_tungstenite::{accept_async, connect_async};
16
17
18pub struct Transport {
19 inner: Box<dyn AsyncReadWrite>,
20}
21
22pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + 'static {
23 #[cfg(not(target_arch = "wasm32"))]
24 fn is_send_sync(&self) where Self: Send + Sync {}
25}
26
27impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static> AsyncReadWrite for T {}
28
29
30
31#[pin_project]
32pub struct WsStreamCompat<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> {
33 #[pin]
34 reader: R,
35 #[pin]
36 writer: W,
37}
38
39impl<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> AsyncRead
40for WsStreamCompat<R, W>
41{
42 fn poll_read(
43 self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 buf: &mut ReadBuf<'_>,
46 ) -> Poll<io::Result<()>> {
47 let unfilled = buf.initialize_unfilled();
48 match self.project().reader.poll_read(cx, unfilled) {
49 Poll::Ready(Ok(n)) => {
50 buf.advance(n);
51 Poll::Ready(Ok(()))
52 }
53 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
54 Poll::Pending => Poll::Pending,
55 }
56 }
57}
58
59impl<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> AsyncWrite
60for WsStreamCompat<R, W>
61{
62 fn poll_write(
63 self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 buf: &[u8],
66 ) -> Poll<io::Result<usize>> {
67 self.project().writer.poll_write(cx, buf)
68 }
69
70 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
71 self.project().writer.poll_flush(cx)
72 }
73
74 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
75 self.project().writer.poll_close(cx)
76 }
77}
78
79impl Transport {
80 #[cfg(not(target_arch = "wasm32"))]
81 pub fn plain(stream: TcpStream) -> Self {
82 Self { inner: Box::new(stream) }
83 }
84
85
86 #[cfg(not(target_arch = "wasm32"))]
87 pub fn tls_server(stream: ServerTlsStream<TcpStream>) -> Self {
88 Self { inner: Box::new(stream) }
89 }
90
91 #[cfg(not(target_arch = "wasm32"))]
92 pub fn tls_client(stream: ClientTlsStream<TcpStream>) -> Self {
93 Self { inner: Box::new(stream) }
94 }
95
96 #[cfg(target_arch = "wasm32")]
99 pub async fn connect(url: &str) -> io::Result<Self> {
100 use ws_stream_wasm::WsMeta;
101 use futures_util::AsyncReadExt;
102 let (_meta, ws_stream) = WsMeta::connect(url, None)
103 .await
104 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
105
106 let (reader, writer) = ws_stream.into_io().split();
107 Ok(Self {
108 inner: Box::new(WsStreamCompat { reader, writer }),
109 })
110 }
111
112 #[cfg(not(target_arch = "wasm32"))]
113 pub async fn connect(url: &str) -> io::Result<Self> {
114 let (ws_stream, _response) = connect_async(url)
115 .await
116 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()))?;
117
118 let (write, read) = ws_stream.split();
119 let reader = ByteReader::new(read);
120 let writer = ByteWriter::new(write);
121
122 Ok(Self {
123 inner: Box::new(WsStreamCompat { reader, writer }),
124 })
125 }
126
127 #[cfg(not(target_arch = "wasm32"))]
128 pub async fn accept_websocket(stream: Transport) -> io::Result<Self> {
129 let ws_stream = accept_async(stream)
130 .await
131 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
132
133 let (write, read) = ws_stream.split();
134 let reader = ByteReader::new(read);
135 let writer = ByteWriter::new(write);
136
137 Ok(Self {
138 inner: Box::new(WsStreamCompat { reader, writer }),
139 })
140 }
141
142 pub fn inner(&mut self) -> &mut dyn AsyncReadWrite {
143 &mut *self.inner
144 }
145}
146
147impl AsyncRead for Transport {
148 fn poll_read(
149 mut self: Pin<&mut Self>,
150 cx: &mut Context<'_>,
151 buf: &mut ReadBuf<'_>,
152 ) -> Poll<io::Result<()>> {
153 Pin::new(&mut *self.inner).poll_read(cx, buf)
154 }
155}
156
157impl AsyncWrite for Transport {
158 fn poll_write(
159 mut self: Pin<&mut Self>,
160 cx: &mut Context<'_>,
161 buf: &[u8],
162 ) -> Poll<io::Result<usize>> {
163 Pin::new(&mut *self.inner).poll_write(cx, buf)
164 }
165
166 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
167 Pin::new(&mut *self.inner).poll_flush(cx)
168 }
169
170 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
171 Pin::new(&mut *self.inner).poll_shutdown(cx)
172 }
173}
174unsafe impl Send for Transport {
175
176}
177unsafe impl Sync for Transport {}