websocket_async_io/
lib.rs1use std::cmp::Ordering;
23use std::pin::Pin;
24use std::task::Poll;
25
26use futures_channel::mpsc::Receiver;
27use futures_core::stream::Stream;
28use futures_io::AsyncBufRead;
29use futures_io::AsyncRead;
30use futures_io::AsyncWrite;
31use js_sys::Uint8Array;
32use wasm_bindgen::prelude::*;
33use wasm_bindgen::JsCast;
34use web_sys::{ErrorEvent, MessageEvent, WebSocket};
35
/// An open browser WebSocket used as a byte stream.
///
/// Values are only produced after the socket's `open` event has fired. Use
/// [`WebsocketIO::split`] to obtain the reading half (fed by incoming binary
/// messages) and the writing half (which sends on `ws`).
pub struct WebsocketIO {
    // The underlying browser socket; handed to the writer half on `split`.
    ws: WebSocket,
    // Receives the payloads forwarded by the socket's message handler.
    reader: WebsocketReader,
}
40
/// Read half of a [`WebsocketIO`]: exposes received message payloads as one
/// continuous byte stream.
struct WebsocketReader {
    /// Payloads forwarded from the socket's `onmessage` handler.
    read_rx: Receiver<Uint8Array>,
    /// Bytes already taken off `read_rx` but not yet handed to the caller.
    remaining: Vec<u8>,
}
/// Write half of a [`WebsocketIO`]: each write is sent as a binary message on `ws`.
struct WebsocketWriter {
    ws: WebSocket,
}
48
49impl WebsocketIO {
50 pub async fn new(addr: &str) -> Result<WebsocketIO, std::io::Error> {
51 WebsocketIO::new_inner(&format!("ws://{}", addr)).await
52 }
53 pub async fn new_wss(addr: &str) -> Result<WebsocketIO, std::io::Error> {
54 WebsocketIO::new_inner(&format!("wss://{}", addr)).await
55 }
56
57 async fn new_inner(url: &str) -> Result<WebsocketIO, std::io::Error> {
58 let ws =
59 WebSocket::new(url).map_err(|e| -> std::io::Error { todo!("map error: {:?}", e) })?;
60
61 let buffer = 4;
62
63 let (open_tx, open_rx) = futures_channel::oneshot::channel();
64 let (read_tx, read_rx) = futures_channel::mpsc::channel(buffer);
65
66 let onmessage_callback = Closure::wrap(Box::new(move |e: MessageEvent| {
67 let mut read_tx = read_tx.clone();
68 let blob = match e.data().dyn_into::<web_sys::Blob>() {
69 Ok(blob) => blob,
70 _ => return,
71 };
72
73 let fr = web_sys::FileReader::new().unwrap();
74 let fr_c = fr.clone();
75 let file_reader_load_end = Closure::wrap(Box::new(move |_e: web_sys::ProgressEvent| {
76 let array = Uint8Array::new(&fr_c.result().unwrap());
77 read_tx.start_send(array).unwrap();
78 })
79 as Box<dyn FnMut(web_sys::ProgressEvent)>);
80 fr.set_onloadend(Some(file_reader_load_end.as_ref().unchecked_ref()));
81 file_reader_load_end.forget();
82
83 fr.read_as_array_buffer(&blob).expect("blob not readable");
84 }) as Box<dyn Fn(MessageEvent)>);
85
86 let onerror_callback =
87 Closure::wrap(Box::new(move |_: ErrorEvent| {}) as Box<dyn FnMut(ErrorEvent)>);
88
89 let mut open_tx = Some(open_tx);
90 let onopen_callback =
91 Closure::wrap(Box::new(move |_| open_tx.take().unwrap().send(()).unwrap())
92 as Box<dyn FnMut(JsValue)>);
93
94 ws.set_onmessage(Some(onmessage_callback.as_ref().unchecked_ref()));
95 onmessage_callback.forget();
96
97 ws.set_onerror(Some(onerror_callback.as_ref().unchecked_ref()));
98 onerror_callback.forget();
99
100 ws.set_onopen(Some(onopen_callback.as_ref().unchecked_ref()));
101 onopen_callback.forget();
102
103 let reader = WebsocketReader {
104 read_rx,
105 remaining: Vec::new(),
106 };
107
108 open_rx.await.unwrap();
109
110 let ws_io = WebsocketIO { ws, reader };
111 Ok(ws_io)
112 }
113
114 pub fn split(self) -> (impl AsyncBufRead, impl AsyncWrite) {
115 let WebsocketIO { ws, reader } = self;
116 (reader, WebsocketWriter { ws })
117 }
118}
119
120impl WebsocketReader {
121 fn write_remaining(&mut self, buf: &mut [u8]) -> usize {
122 match self.remaining.len().cmp(&buf.len()) {
123 Ordering::Less => {
124 let amount = self.remaining.len();
125 buf[0..amount].copy_from_slice(&self.remaining);
126 self.remaining.clear();
127 amount
128 }
129 Ordering::Equal => {
130 buf.copy_from_slice(&self.remaining);
131 self.remaining.clear();
132 buf.len()
133 }
134 Ordering::Greater => {
135 let amount = buf.len();
136 buf.copy_from_slice(&self.remaining[..amount]);
137 self.remaining.drain(0..amount);
138 amount
139 }
140 }
141 }
142}
143
144impl AsyncRead for WebsocketReader {
145 fn poll_read(
146 mut self: std::pin::Pin<&mut Self>,
147 cx: &mut std::task::Context<'_>,
148 buf: &mut [u8],
149 ) -> Poll<std::io::Result<usize>> {
150 if !self.remaining.is_empty() {
151 return Poll::Ready(Ok(self.write_remaining(buf)));
152 }
153
154 let array = match Pin::new(&mut self.read_rx).poll_next(cx) {
155 Poll::Ready(Some(item)) => item,
156 Poll::Ready(None) => return Poll::Pending,
157 Poll::Pending => return Poll::Pending,
158 };
159
160 let array_length = array.length() as usize;
161
162 let read = match array_length.cmp(&buf.len()) {
163 Ordering::Equal => {
164 array.copy_to(buf);
165 buf.len()
166 }
167 Ordering::Less => {
168 array.copy_to(&mut buf[..array_length]);
169 array_length
170 }
171 Ordering::Greater => {
172 self.remaining.resize(array_length, 0);
173 array.copy_to(self.as_mut().remaining.as_mut_slice());
174
175 self.write_remaining(buf)
176 }
177 };
178
179 Poll::Ready(Ok(read))
180 }
181}
182impl AsyncBufRead for WebsocketReader {
183 fn poll_fill_buf(
184 mut self: std::pin::Pin<&mut Self>,
185 cx: &mut std::task::Context<'_>,
186 ) -> Poll<futures_io::Result<&[u8]>> {
187 if !self.remaining.is_empty() {
188 return Poll::Ready(Ok(self.get_mut().remaining.as_slice()));
189 }
190
191 let array = match Pin::new(&mut self.read_rx).poll_next(cx) {
192 Poll::Ready(Some(item)) => item,
193 Poll::Ready(None) => return Poll::Pending,
194 Poll::Pending => return Poll::Pending,
195 };
196
197 self.remaining.extend(&array.to_vec());
198
199 if self.remaining.len() == 0 {
200 return Poll::Pending;
201 }
202 Poll::Ready(Ok(self.get_mut().remaining.as_slice()))
203 }
204
205 fn consume(mut self: std::pin::Pin<&mut Self>, amt: usize) {
206 if self.remaining.len() == amt {
207 self.remaining.clear();
208 return;
209 }
210 self.remaining.drain(0..amt);
211 }
212}
213
214impl AsyncWrite for WebsocketWriter {
215 fn poll_write(
216 self: std::pin::Pin<&mut Self>,
217 _: &mut std::task::Context<'_>,
218 buf: &[u8],
219 ) -> Poll<std::io::Result<usize>> {
220 self.ws.send_with_u8_array(buf).unwrap();
221
222 Poll::Ready(Ok(buf.len()))
223 }
224
225 fn poll_flush(
226 self: std::pin::Pin<&mut Self>,
227 _: &mut std::task::Context<'_>,
228 ) -> Poll<std::io::Result<()>> {
229 Poll::Ready(Ok(()))
230 }
231
232 fn poll_close(
233 self: std::pin::Pin<&mut Self>,
234 _: &mut std::task::Context<'_>,
235 ) -> Poll<std::io::Result<()>> {
236 self.ws.close().unwrap();
237 Poll::Ready(Ok(()))
238 }
239}