1#![cfg_attr(docsrs, feature(doc_cfg))]
58
59#[cfg(feature = "gloo")]
60#[cfg_attr(docsrs, doc(cfg(feature = "gloo")))]
61pub mod gloo;
62#[cfg(feature = "tungstenite")]
63#[cfg_attr(docsrs, doc(cfg(feature = "tungstenite")))]
64pub mod tungstenite;
65
66use futures::{ready, AsyncBufRead, AsyncRead, AsyncWrite, Sink, Stream};
67use pin_project::pin_project;
68use std::{io, marker::PhantomData, pin::Pin, task::Poll};
69#[cfg(feature = "tokio")]
70#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
71use tokio::io::{
72 AsyncBufRead as TokioAsyncBufRead, AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite,
73};
74
/// Classification of an incoming WebSocket message, as reported by
/// [`WsMessageHandle::message_into_kind`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessageKind {
    /// Payload bytes to expose through the byte stream.
    Bytes(Vec<u8>),
    /// The peer closed the connection; the read side reports end-of-stream.
    Close,
    /// A message that carries no stream data; the reader skips it.
    Other,
}
83
/// How a [`WsByteStream`] should interpret an error coming from the wrapped
/// stream or sink, as reported by [`WsMessageHandle::error_into_kind`].
#[derive(Debug)]
pub enum WsErrorKind {
    /// An I/O error; handed to the caller unchanged.
    Io(io::Error),
    /// The connection was closed; treated as a clean shutdown
    /// (reads report end-of-stream, writes stop).
    Closed,
    /// The connection had already been closed; surfaced as
    /// `io::ErrorKind::NotConnected`.
    AlreadyClosed,
    /// Any other error; wrapped in an `io::Error` of kind `Other`.
    Other(Box<dyn std::error::Error + Sync + Send>),
}
92
93pub trait WsMessageHandle<Msg, E> {
95 fn message_into_kind(msg: Msg) -> WsMessageKind;
96 fn error_into_kind(e: E) -> WsErrorKind;
97 fn message_from_bytes<T: Into<Vec<u8>>>(bytes: T) -> Msg;
99
100 fn wrap_stream<S>(inner: S) -> WsByteStream<S, Msg, E, Self>
101 where
102 S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
103 {
104 WsByteStream::new(inner)
105 }
106}
107
108#[pin_project]
119pub struct WsByteStream<S, Msg, E, H>
120where
121 S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
122 H: WsMessageHandle<Msg, E> + ?Sized,
123{
124 #[pin]
125 inner: S,
126 state: State,
127 _marker: PhantomData<H>,
128}
129
130#[derive(Debug)]
131struct State {
132 read: ReadState,
133 write: WriteState,
134}
135
/// Read-side state of a [`WsByteStream`].
#[derive(Debug)]
enum ReadState {
    /// Nothing is buffered; the next read polls the inner stream.
    Pending,
    /// One message's payload is buffered.
    Ready {
        /// Payload of the current message.
        buf: Vec<u8>,
        /// How many bytes of `buf` have already been handed to the reader;
        /// `buf[amt_read..]` is still unread.
        amt_read: usize,
    },
    /// The read side is finished; reads report end-of-stream.
    Terminated,
}
145
/// Write-side state of a [`WsByteStream`].
#[derive(Debug)]
enum WriteState {
    /// Writes are forwarded to the inner sink.
    Ready,
    /// The sink was closed (or reported closed); writes fail with `NotConnected`.
    Closed,
}
151
152impl<S, Msg, E, H> WsByteStream<S, Msg, E, H>
153where
154 S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
155 H: WsMessageHandle<Msg, E> + ?Sized,
156{
157 pub fn new(inner: S) -> Self {
158 Self {
159 inner,
160 state: State {
161 read: ReadState::Pending,
162 write: WriteState::Ready,
163 },
164 _marker: PhantomData,
165 }
166 }
167
168 fn fill_buf_with_next_msg(
172 self: Pin<&mut Self>,
173 cx: &mut std::task::Context<'_>,
174 ) -> Poll<Option<io::Result<()>>> {
175 let mut this = self.project();
176 loop {
177 let res = ready!(this.inner.as_mut().poll_next(cx));
179 let Some(res) = res else {
180 this.state.read = ReadState::Terminated;
182 return Poll::Ready(None);
183 };
184 match res {
185 Ok(msg) => {
186 let msg = H::message_into_kind(msg);
187 match msg {
188 WsMessageKind::Bytes(msg) => {
189 this.state.read = ReadState::Ready {
190 buf: msg,
191 amt_read: 0,
192 };
193 return Poll::Ready(Some(Ok(())));
194 }
195 WsMessageKind::Close => {
196 this.state.read = ReadState::Terminated;
197 return Poll::Ready(None);
198 }
199 WsMessageKind::Other => {
200 continue;
202 }
203 }
204 }
205 Err(e) => {
206 let e = H::error_into_kind(e);
207 match e {
208 WsErrorKind::Io(e) => {
209 return Poll::Ready(Some(Err(e)));
210 }
211 WsErrorKind::Closed => {
212 this.state.read = ReadState::Terminated;
213 return Poll::Ready(None);
214 }
215 WsErrorKind::AlreadyClosed => {
216 this.state.read = ReadState::Terminated;
217 let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
218 return Poll::Ready(Some(Err(e)));
219 }
220 WsErrorKind::Other(e) => {
221 return Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))));
222 }
223 }
224 }
225 }
226 }
227 }
228}
229
230impl<S, Msg, E, H> AsyncRead for WsByteStream<S, Msg, E, H>
231where
232 S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
233 H: WsMessageHandle<Msg, E> + ?Sized,
234{
235 fn poll_read(
236 mut self: Pin<&mut Self>,
237 cx: &mut std::task::Context<'_>,
238 dst: &mut [u8],
239 ) -> Poll<io::Result<usize>> {
240 loop {
241 let this = self.as_mut().project();
242 match this.state.read {
243 ReadState::Pending => {
244 let res = ready!(self.as_mut().fill_buf_with_next_msg(cx));
245 match res {
246 Some(Ok(())) => continue, Some(Err(e)) => return Poll::Ready(Err(e)),
248 None => continue, }
250 }
251 ReadState::Ready {
252 ref buf,
253 ref mut amt_read,
254 } => {
255 let buf = &buf[*amt_read..];
256 let len = std::cmp::min(dst.len(), buf.len());
257 dst[..len].copy_from_slice(&buf[..len]);
258 if len == buf.len() {
259 this.state.read = ReadState::Pending;
260 } else {
261 *amt_read += len;
262 }
263 return Poll::Ready(Ok(len));
264 }
265 ReadState::Terminated => {
266 return Poll::Ready(Ok(0));
267 }
268 }
269 }
270 }
271}
272
273impl<S, Msg, E, H> AsyncBufRead for WsByteStream<S, Msg, E, H>
274where
275 S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
276 H: WsMessageHandle<Msg, E> + ?Sized,
277{
278 fn poll_fill_buf(
279 mut self: Pin<&mut Self>,
280 cx: &mut std::task::Context<'_>,
281 ) -> Poll<io::Result<&[u8]>> {
282 loop {
283 let this = self.as_mut().project();
284 match this.state.read {
285 ReadState::Pending => {
286 let res = ready!(self.as_mut().fill_buf_with_next_msg(cx));
287 match res {
288 Some(Ok(())) => continue, Some(Err(e)) => return Poll::Ready(Err(e)),
290 None => continue, }
292 }
293 ReadState::Ready { .. } => {
294 let this = self.project();
296 let ReadState::Ready { ref buf, amt_read } = this.state.read else {
297 unreachable!()
298 };
299 return Poll::Ready(Ok(&buf[amt_read..]));
300 }
301 ReadState::Terminated => {
302 return Poll::Ready(Ok(&[]));
303 }
304 }
305 }
306 }
307
308 fn consume(mut self: Pin<&mut Self>, amt: usize) {
309 let ReadState::Ready {
310 ref buf,
311 ref mut amt_read,
312 } = self.state.read
313 else {
314 return;
315 };
316 *amt_read = std::cmp::min(buf.len(), *amt_read + amt);
317 if *amt_read == buf.len() {
318 self.state.read = ReadState::Pending;
319 }
320 }
321}
322
323impl<S, Msg, E, H> AsyncWrite for WsByteStream<S, Msg, E, H>
324where
325 S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
326 H: WsMessageHandle<Msg, E> + ?Sized,
327{
328 fn poll_write(
329 mut self: Pin<&mut Self>,
330 cx: &mut std::task::Context<'_>,
331 buf: &[u8],
332 ) -> Poll<io::Result<usize>> {
333 let mut this = self.as_mut().project();
334 loop {
335 match this.state.write {
336 WriteState::Ready => {
337 if let Err(e) = ready!(this.inner.as_mut().poll_ready(cx)) {
338 let e = H::error_into_kind(e);
339 match e {
340 WsErrorKind::Io(e) => {
341 return Poll::Ready(Err(e));
342 }
343 WsErrorKind::Other(e) => {
344 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
345 }
346 WsErrorKind::Closed => {
347 this.state.write = WriteState::Closed;
348 return Poll::Ready(Ok(0));
349 }
350 WsErrorKind::AlreadyClosed => {
351 this.state.write = WriteState::Closed;
352 let e =
353 io::Error::new(io::ErrorKind::NotConnected, "Already closed");
354 return Poll::Ready(Err(e));
355 }
356 }
357 }
358 let Err(e) = this.inner.as_mut().start_send(H::message_from_bytes(buf)) else {
360 this.state.write = WriteState::Ready;
361 return Poll::Ready(Ok(buf.len()));
362 };
363 let e = H::error_into_kind(e);
364 match e {
365 WsErrorKind::Io(e) => {
366 return Poll::Ready(Err(e));
367 }
368 WsErrorKind::Other(e) => {
369 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
370 }
371 WsErrorKind::Closed => {
372 this.state.write = WriteState::Closed;
373 return Poll::Ready(Ok(0));
374 }
375 WsErrorKind::AlreadyClosed => {
376 this.state.write = WriteState::Closed;
377 let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
378 return Poll::Ready(Err(e));
379 }
380 }
381 }
382 WriteState::Closed => {
383 let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
384 return Poll::Ready(Err(e));
385 }
386 }
387 }
388 }
389
390 fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
391 let mut this = self.project();
392 if let Err(e) = ready!(this.inner.as_mut().poll_flush(cx)) {
393 let e = H::error_into_kind(e);
394 match e {
395 WsErrorKind::Io(e) => {
396 return Poll::Ready(Err(e));
397 }
398 WsErrorKind::Other(e) => {
399 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
400 }
401 WsErrorKind::Closed => {
402 this.state.write = WriteState::Closed;
403 return Poll::Ready(Ok(()));
404 }
405 WsErrorKind::AlreadyClosed => {
406 this.state.write = WriteState::Closed;
407 let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
408 return Poll::Ready(Err(e));
409 }
410 }
411 }
412 Poll::Ready(Ok(()))
413 }
414
415 fn poll_close(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
416 let mut this = self.project();
417 this.state.write = WriteState::Closed;
418 if let Err(e) = ready!(this.inner.as_mut().poll_close(cx)) {
419 let e = H::error_into_kind(e);
420 match e {
421 WsErrorKind::Io(e) => {
422 return Poll::Ready(Err(e));
423 }
424 WsErrorKind::Other(e) => {
425 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
426 }
427 WsErrorKind::Closed => {
428 return Poll::Ready(Ok(()));
429 }
430 WsErrorKind::AlreadyClosed => {
431 let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
432 return Poll::Ready(Err(e));
433 }
434 }
435 }
436 Poll::Ready(Ok(()))
437 }
438}
439
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<S, Msg, E, H> TokioAsyncRead for WsByteStream<S, Msg, E, H>
where
    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
    H: WsMessageHandle<Msg, E> + ?Sized,
{
    /// Copies buffered message bytes into `buf`.
    ///
    /// Reads through the `futures` `AsyncBufRead` implementation and copies with
    /// `ReadBuf::put_slice`, so only the bytes actually delivered are written.
    /// The rest of the caller's unfilled buffer is never zero-initialized, which
    /// `ReadBuf::initialize_unfilled` would do on every call for uninitialized buffers.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let chunk = ready!(AsyncBufRead::poll_fill_buf(self.as_mut(), cx))?;
        let n = chunk.len().min(buf.remaining());
        buf.put_slice(&chunk[..n]);
        AsyncBufRead::consume(self, n);
        Poll::Ready(Ok(()))
    }
}
458
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<S, Msg, E, H> TokioAsyncBufRead for WsByteStream<S, Msg, E, H>
where
    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
    H: WsMessageHandle<Msg, E> + ?Sized,
{
    // Both methods forward to the `futures` `AsyncBufRead` implementation,
    // so the two trait families share one buffer and one state machine.

    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<&[u8]>> {
        <Self as AsyncBufRead>::poll_fill_buf(self, cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        <Self as AsyncBufRead>::consume(self, amt);
    }
}
477
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<S, Msg, E, H> TokioAsyncWrite for WsByteStream<S, Msg, E, H>
where
    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
    H: WsMessageHandle<Msg, E> + ?Sized,
{
    // Every method forwards to the `futures` `AsyncWrite` implementation;
    // tokio's `poll_shutdown` corresponds to futures' `poll_close`.

    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        <Self as AsyncWrite>::poll_write(self, cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
        <Self as AsyncWrite>::poll_flush(self, cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        <Self as AsyncWrite>::poll_close(self, cx)
    }
}