websocket_async_io/
lib.rs

1//! Implementations of [`AsyncRead`](https://docs.rs/futures/0.3.17/futures/io/trait.AsyncRead.html) and [`AsyncWrite`](https://docs.rs/futures/0.3.17/futures/io/trait.AsyncWrite.html) on top of websockets using [`web-sys`](https://github.com/rustwasm/wasm-bindgen/tree/master/crates/web-sys).
2//! # Example
3//! ```rust,no_run
4//! # async fn run() -> Result<(), std::io::Error> {
5//! let ws = WebsocketIO::new("localhost:8000").await?;
6//! let (mut reader, mut writer) = ws.split();
7//!
8//! writer.write_all(&[0, 1, 2, 3, 93]).await?;
9//! writer.write_all(&[42, 34, 93]).await?;
10//! writer.write_all(&[0, 0, 1, 2, 93]).await?;
11//!
12//! let mut buf = Vec::new();
13//! for _ in 0..3 {
14//!     reader.read_until(93, &mut buf).await?;
15//!     console_log!("{:?}", buf);
16//!     buf.clear();
17//! }
18//!
19//! # Ok(())
20//! # }
21//! ```
22use std::cmp::Ordering;
23use std::pin::Pin;
24use std::task::Poll;
25
26use futures_channel::mpsc::Receiver;
27use futures_core::stream::Stream;
28use futures_io::AsyncBufRead;
29use futures_io::AsyncRead;
30use futures_io::AsyncWrite;
31use js_sys::Uint8Array;
32use wasm_bindgen::prelude::*;
33use wasm_bindgen::JsCast;
34use web_sys::{ErrorEvent, MessageEvent, WebSocket};
35
36pub struct WebsocketIO {
37    ws: WebSocket,
38    reader: WebsocketReader,
39}
40
41struct WebsocketReader {
42    read_rx: Receiver<Uint8Array>,
43    remaining: Vec<u8>,
44}
45struct WebsocketWriter {
46    ws: WebSocket,
47}
48
49impl WebsocketIO {
50    pub async fn new(addr: &str) -> Result<WebsocketIO, std::io::Error> {
51        WebsocketIO::new_inner(&format!("ws://{}", addr)).await
52    }
53    pub async fn new_wss(addr: &str) -> Result<WebsocketIO, std::io::Error> {
54        WebsocketIO::new_inner(&format!("wss://{}", addr)).await
55    }
56
57    async fn new_inner(url: &str) -> Result<WebsocketIO, std::io::Error> {
58        let ws =
59            WebSocket::new(url).map_err(|e| -> std::io::Error { todo!("map error: {:?}", e) })?;
60
61        let buffer = 4;
62
63        let (open_tx, open_rx) = futures_channel::oneshot::channel();
64        let (read_tx, read_rx) = futures_channel::mpsc::channel(buffer);
65
66        let onmessage_callback = Closure::wrap(Box::new(move |e: MessageEvent| {
67            let mut read_tx = read_tx.clone();
68            let blob = match e.data().dyn_into::<web_sys::Blob>() {
69                Ok(blob) => blob,
70                _ => return,
71            };
72
73            let fr = web_sys::FileReader::new().unwrap();
74            let fr_c = fr.clone();
75            let file_reader_load_end = Closure::wrap(Box::new(move |_e: web_sys::ProgressEvent| {
76                let array = Uint8Array::new(&fr_c.result().unwrap());
77                read_tx.start_send(array).unwrap();
78            })
79                as Box<dyn FnMut(web_sys::ProgressEvent)>);
80            fr.set_onloadend(Some(file_reader_load_end.as_ref().unchecked_ref()));
81            file_reader_load_end.forget();
82
83            fr.read_as_array_buffer(&blob).expect("blob not readable");
84        }) as Box<dyn Fn(MessageEvent)>);
85
86        let onerror_callback =
87            Closure::wrap(Box::new(move |_: ErrorEvent| {}) as Box<dyn FnMut(ErrorEvent)>);
88
89        let mut open_tx = Some(open_tx);
90        let onopen_callback =
91            Closure::wrap(Box::new(move |_| open_tx.take().unwrap().send(()).unwrap())
92                as Box<dyn FnMut(JsValue)>);
93
94        ws.set_onmessage(Some(onmessage_callback.as_ref().unchecked_ref()));
95        onmessage_callback.forget();
96
97        ws.set_onerror(Some(onerror_callback.as_ref().unchecked_ref()));
98        onerror_callback.forget();
99
100        ws.set_onopen(Some(onopen_callback.as_ref().unchecked_ref()));
101        onopen_callback.forget();
102
103        let reader = WebsocketReader {
104            read_rx,
105            remaining: Vec::new(),
106        };
107
108        open_rx.await.unwrap();
109
110        let ws_io = WebsocketIO { ws, reader };
111        Ok(ws_io)
112    }
113
114    pub fn split(self) -> (impl AsyncBufRead, impl AsyncWrite) {
115        let WebsocketIO { ws, reader } = self;
116        (reader, WebsocketWriter { ws })
117    }
118}
119
120impl WebsocketReader {
121    fn write_remaining(&mut self, buf: &mut [u8]) -> usize {
122        match self.remaining.len().cmp(&buf.len()) {
123            Ordering::Less => {
124                let amount = self.remaining.len();
125                buf[0..amount].copy_from_slice(&self.remaining);
126                self.remaining.clear();
127                amount
128            }
129            Ordering::Equal => {
130                buf.copy_from_slice(&self.remaining);
131                self.remaining.clear();
132                buf.len()
133            }
134            Ordering::Greater => {
135                let amount = buf.len();
136                buf.copy_from_slice(&self.remaining[..amount]);
137                self.remaining.drain(0..amount);
138                amount
139            }
140        }
141    }
142}
143
144impl AsyncRead for WebsocketReader {
145    fn poll_read(
146        mut self: std::pin::Pin<&mut Self>,
147        cx: &mut std::task::Context<'_>,
148        buf: &mut [u8],
149    ) -> Poll<std::io::Result<usize>> {
150        if !self.remaining.is_empty() {
151            return Poll::Ready(Ok(self.write_remaining(buf)));
152        }
153
154        let array = match Pin::new(&mut self.read_rx).poll_next(cx) {
155            Poll::Ready(Some(item)) => item,
156            Poll::Ready(None) => return Poll::Pending,
157            Poll::Pending => return Poll::Pending,
158        };
159
160        let array_length = array.length() as usize;
161
162        let read = match array_length.cmp(&buf.len()) {
163            Ordering::Equal => {
164                array.copy_to(buf);
165                buf.len()
166            }
167            Ordering::Less => {
168                array.copy_to(&mut buf[..array_length]);
169                array_length
170            }
171            Ordering::Greater => {
172                self.remaining.resize(array_length, 0);
173                array.copy_to(self.as_mut().remaining.as_mut_slice());
174
175                self.write_remaining(buf)
176            }
177        };
178
179        Poll::Ready(Ok(read))
180    }
181}
182impl AsyncBufRead for WebsocketReader {
183    fn poll_fill_buf(
184        mut self: std::pin::Pin<&mut Self>,
185        cx: &mut std::task::Context<'_>,
186    ) -> Poll<futures_io::Result<&[u8]>> {
187        if !self.remaining.is_empty() {
188            return Poll::Ready(Ok(self.get_mut().remaining.as_slice()));
189        }
190
191        let array = match Pin::new(&mut self.read_rx).poll_next(cx) {
192            Poll::Ready(Some(item)) => item,
193            Poll::Ready(None) => return Poll::Pending,
194            Poll::Pending => return Poll::Pending,
195        };
196
197        self.remaining.extend(&array.to_vec());
198
199        if self.remaining.len() == 0 {
200            return Poll::Pending;
201        }
202        Poll::Ready(Ok(self.get_mut().remaining.as_slice()))
203    }
204
205    fn consume(mut self: std::pin::Pin<&mut Self>, amt: usize) {
206        if self.remaining.len() == amt {
207            self.remaining.clear();
208            return;
209        }
210        self.remaining.drain(0..amt);
211    }
212}
213
214impl AsyncWrite for WebsocketWriter {
215    fn poll_write(
216        self: std::pin::Pin<&mut Self>,
217        _: &mut std::task::Context<'_>,
218        buf: &[u8],
219    ) -> Poll<std::io::Result<usize>> {
220        self.ws.send_with_u8_array(buf).unwrap();
221
222        Poll::Ready(Ok(buf.len()))
223    }
224
225    fn poll_flush(
226        self: std::pin::Pin<&mut Self>,
227        _: &mut std::task::Context<'_>,
228    ) -> Poll<std::io::Result<()>> {
229        Poll::Ready(Ok(()))
230    }
231
232    fn poll_close(
233        self: std::pin::Pin<&mut Self>,
234        _: &mut std::task::Context<'_>,
235    ) -> Poll<std::io::Result<()>> {
236        self.ws.close().unwrap();
237        Poll::Ready(Ok(()))
238    }
239}