#![warn(missing_docs)]
//! WebSocket client for WebAssembly targets.
//!
//! A connection is backed either by the browser's `WebSocketStream` interface
//! or by the standard `WebSocket` interface; see [`Interface`].

// The crate only works on top of browser JavaScript APIs, so fail the build
// early with a clear message on any non-wasm target.
#[cfg(not(target_family = "wasm"))]
compile_error!("websocket-web requires a WebAssembly target");
61
62mod closed;
63mod standard;
64mod stream;
65mod util;
66
67use futures_core::Stream;
68use futures_sink::Sink;
69use futures_util::{SinkExt, StreamExt};
70use js_sys::{Reflect, Uint8Array};
71use std::{
72 fmt, io,
73 io::ErrorKind,
74 mem,
75 pin::Pin,
76 rc::Rc,
77 task::{ready, Context, Poll},
78};
79use tokio::io::{AsyncRead, AsyncWrite};
80use wasm_bindgen::prelude::*;
81
82pub use closed::{CloseCode, Closed, ClosedReason};
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq)]
86pub enum Interface {
87 Stream,
91 Standard,
95}
96
97impl Interface {
98 pub fn is_supported(&self) -> bool {
100 let global = js_sys::global();
101 match self {
102 Self::Stream => Reflect::has(&global, &JsValue::from_str("WebSocketStream")).unwrap_or_default(),
103 Self::Standard => Reflect::has(&global, &JsValue::from_str("WebSocket")).unwrap_or_default(),
104 }
105 }
106}
107
/// A WebSocket message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Msg {
    /// Text message.
    Text(String),
    /// Binary message.
    Binary(Vec<u8>),
}

impl Msg {
    /// Returns `true` if this is a text message.
    pub const fn is_text(&self) -> bool {
        matches!(self, Self::Text(_))
    }

    /// Returns `true` if this is a binary message.
    pub const fn is_binary(&self) -> bool {
        matches!(self, Self::Binary(_))
    }

    /// Converts the message into its raw bytes.
    ///
    /// Text is returned as its UTF-8 encoding. The message is consumed, so the
    /// existing buffer is reused and no copy is made.
    pub fn to_vec(self) -> Vec<u8> {
        match self {
            // `into_bytes` hands over the String's buffer instead of copying it.
            Self::Text(text) => text.into_bytes(),
            Self::Binary(vec) => vec,
        }
    }

    /// Length of the message payload in bytes (UTF-8 bytes for text).
    pub fn len(&self) -> usize {
        match self {
            Self::Text(text) => text.len(),
            Self::Binary(vec) => vec.len(),
        }
    }

    /// Returns `true` if the message payload is empty.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Text(text) => text.is_empty(),
            Self::Binary(vec) => vec.is_empty(),
        }
    }
}

/// Displays text verbatim; binary data is shown as lossily decoded UTF-8.
impl fmt::Display for Msg {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Text(text) => write!(f, "{text}"),
            Self::Binary(binary) => write!(f, "{}", String::from_utf8_lossy(binary)),
        }
    }
}

impl From<Msg> for Vec<u8> {
    fn from(msg: Msg) -> Self {
        msg.to_vec()
    }
}

/// Borrows the payload bytes (UTF-8 bytes for text).
impl AsRef<[u8]> for Msg {
    fn as_ref(&self) -> &[u8] {
        match self {
            Self::Text(text) => text.as_bytes(),
            Self::Binary(vec) => vec,
        }
    }
}
176
/// Builder for configuring and establishing a [`WebSocket`] connection.
#[derive(Debug, Clone)]
pub struct WebSocketBuilder {
    // Server URL to connect to.
    url: String,
    // Protocols requested for the connection.
    protocols: Vec<String>,
    // Explicitly selected interface; `None` lets `connect` choose automatically.
    interface: Option<Interface>,
    // Optional send buffer size, interpreted by the selected backend.
    send_buffer_size: Option<usize>,
    // Optional receive buffer size, interpreted by the selected backend.
    receive_buffer_size: Option<usize>,
}
186
187impl WebSocketBuilder {
188 pub fn new(url: impl AsRef<str>) -> Self {
190 Self {
191 url: url.as_ref().to_string(),
192 protocols: Vec::new(),
193 interface: None,
194 send_buffer_size: None,
195 receive_buffer_size: None,
196 }
197 }
198
199 pub fn set_interface(&mut self, interface: Interface) {
203 self.interface = Some(interface);
204 }
205
206 pub fn set_protocols<P>(&mut self, protocols: impl IntoIterator<Item = P>)
218 where
219 P: AsRef<str>,
220 {
221 self.protocols = protocols.into_iter().map(|s| s.as_ref().to_string()).collect();
222 }
223
224 pub fn set_send_buffer_size(&mut self, send_buffer_size: usize) {
231 self.send_buffer_size = Some(send_buffer_size);
232 }
233
234 pub fn set_receive_buffer_size(&mut self, receive_buffer_size: usize) {
241 self.receive_buffer_size = Some(receive_buffer_size);
242 }
243
244 pub async fn connect(self) -> io::Result<WebSocket> {
246 let interface = match self.interface {
247 Some(interface) => interface,
248 None if Interface::Stream.is_supported() => Interface::Stream,
249 None => Interface::Standard,
250 };
251
252 if !interface.is_supported() {
253 match interface {
254 Interface::Stream => {
255 return Err(io::Error::new(ErrorKind::Unsupported, "WebSocketStream not supported"))
256 }
257 Interface::Standard => {
258 return Err(io::Error::new(ErrorKind::Unsupported, "WebSocket not supported"))
259 }
260 }
261 }
262
263 match interface {
264 Interface::Stream => {
265 let (stream, info) = stream::Inner::new(self).await?;
266 Ok(WebSocket { inner: Inner::Stream(stream), info: Rc::new(info), read_buf: Vec::new() })
267 }
268 Interface::Standard => {
269 let (standard, info) = standard::Inner::new(self).await?;
270 Ok(WebSocket { inner: Inner::Standard(standard), info: Rc::new(info), read_buf: Vec::new() })
271 }
272 }
273 }
274}
275
/// Connection metadata shared, via `Rc`, between a socket and its split halves.
struct Info {
    // URL the connection was made to.
    url: String,
    // Protocol string reported for the connection (presumably the server-selected
    // subprotocol — TODO confirm against the backend modules).
    protocol: String,
    // Interface backing the connection.
    interface: Interface,
}
281
/// A WebSocket connection.
///
/// Send via the [`Sink`] or [`AsyncWrite`] implementations; receive via the
/// [`Stream`] or [`AsyncRead`] implementations. Use [`WebSocket::into_split`]
/// to obtain independent sending and receiving halves.
pub struct WebSocket {
    // Backend-specific connection state.
    inner: Inner,
    // Metadata shared with split halves.
    info: Rc<Info>,
    // Bytes of the current message not yet handed out by `AsyncRead`.
    read_buf: Vec<u8>,
}

// Backend connection state, one variant per `Interface`.
enum Inner {
    Stream(stream::Inner),
    Standard(standard::Inner),
}
295
296impl fmt::Debug for WebSocket {
297 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298 f.debug_struct("WebSocket")
299 .field("url", &self.info.url)
300 .field("protocol", &self.protocol())
301 .field("interface", &self.interface())
302 .finish()
303 }
304}
305
306impl WebSocket {
307 pub async fn connect(url: impl AsRef<str>) -> io::Result<Self> {
309 WebSocketBuilder::new(url).connect().await
310 }
311
312 pub fn url(&self) -> &str {
314 &self.info.url
315 }
316
317 pub fn protocol(&self) -> &str {
323 &self.info.protocol
324 }
325
326 pub fn interface(&self) -> Interface {
328 self.info.interface
329 }
330
331 pub fn into_split(self) -> (WebSocketSender, WebSocketReceiver) {
333 let Self { inner, info, read_buf } = self;
334 match inner {
335 Inner::Stream(inner) => {
336 let (sender, receiver) = inner.into_split();
337 let sender = WebSocketSender { inner: SenderInner::Stream(sender), info: info.clone() };
338 let receiver = WebSocketReceiver { inner: ReceiverInner::Stream(receiver), info, read_buf };
339 (sender, receiver)
340 }
341 Inner::Standard(inner) => {
342 let (sender, receiver) = inner.into_split();
343 let sender = WebSocketSender { inner: SenderInner::Standard(sender), info: info.clone() };
344 let receiver =
345 WebSocketReceiver { inner: ReceiverInner::Standard(receiver), info, read_buf: Vec::new() };
346 (sender, receiver)
347 }
348 }
349 }
350
351 pub fn close(self) {
353 self.into_split().0.close();
354 }
355
356 #[track_caller]
362 pub fn close_with_reason(self, code: CloseCode, reason: &str) {
363 self.into_split().0.close_with_reason(code, reason);
364 }
365
366 pub fn closed(&self) -> Closed {
368 match &self.inner {
369 Inner::Stream(inner) => inner.closed(),
370 Inner::Standard(inner) => inner.closed(),
371 }
372 }
373
374 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
375 match &mut self.inner {
376 Inner::Stream(inner) => inner.sender.poll_ready_unpin(cx),
377 Inner::Standard(inner) => inner.sender.poll_ready_unpin(cx),
378 }
379 }
380
381 fn start_send(mut self: Pin<&mut Self>, item: &JsValue) -> Result<(), io::Error> {
382 match &mut self.inner {
383 Inner::Stream(inner) => inner.sender.start_send_unpin(item),
384 Inner::Standard(inner) => inner.sender.start_send_unpin(item),
385 }
386 }
387
388 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
389 match &mut self.inner {
390 Inner::Stream(inner) => inner.sender.poll_flush_unpin(cx),
391 Inner::Standard(inner) => inner.sender.poll_flush_unpin(cx),
392 }
393 }
394
395 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
396 match &mut self.inner {
397 Inner::Stream(inner) => inner.sender.poll_close_unpin(cx),
398 Inner::Standard(inner) => inner.sender.poll_close_unpin(cx),
399 }
400 }
401}
402
403impl Sink<&str> for WebSocket {
404 type Error = io::Error;
405
406 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
407 self.poll_ready(cx)
408 }
409
410 fn start_send(self: Pin<&mut Self>, item: &str) -> Result<(), Self::Error> {
411 self.start_send(&JsValue::from_str(item))
412 }
413
414 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
415 self.poll_flush(cx)
416 }
417
418 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
419 self.poll_close(cx)
420 }
421}
422
423impl Sink<String> for WebSocket {
424 type Error = io::Error;
425
426 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
427 self.poll_ready(cx)
428 }
429
430 fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
431 self.start_send(&JsValue::from_str(&item))
432 }
433
434 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
435 self.poll_flush(cx)
436 }
437
438 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
439 self.poll_close(cx)
440 }
441}
442
443impl Sink<&[u8]> for WebSocket {
444 type Error = io::Error;
445
446 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
447 self.poll_ready(cx)
448 }
449
450 fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
451 self.start_send(&Uint8Array::from(item))
452 }
453
454 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
455 self.poll_flush(cx)
456 }
457
458 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
459 self.poll_close(cx)
460 }
461}
462
463impl Sink<Vec<u8>> for WebSocket {
464 type Error = io::Error;
465
466 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
467 self.poll_ready(cx)
468 }
469
470 fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
471 self.start_send(&Uint8Array::from(&item[..]))
472 }
473
474 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
475 self.poll_flush(cx)
476 }
477
478 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
479 self.poll_close(cx)
480 }
481}
482
483impl Sink<Msg> for WebSocket {
484 type Error = io::Error;
485
486 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
487 self.poll_ready(cx)
488 }
489
490 fn start_send(self: Pin<&mut Self>, item: Msg) -> Result<(), Self::Error> {
491 match item {
492 Msg::Text(text) => self.start_send(&JsValue::from_str(&text)),
493 Msg::Binary(vec) => self.start_send(&Uint8Array::from(&vec[..])),
494 }
495 }
496
497 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
498 self.poll_flush(cx)
499 }
500
501 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
502 self.poll_close(cx)
503 }
504}
505
506impl AsyncWrite for WebSocket {
507 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
508 ready!(self.as_mut().poll_ready(cx))?;
509 self.start_send(&Uint8Array::from(buf))?;
510 Poll::Ready(Ok(buf.len()))
511 }
512
513 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
514 self.poll_flush(cx)
515 }
516
517 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
518 self.poll_close(cx)
519 }
520}
521
522impl Stream for WebSocket {
523 type Item = io::Result<Msg>;
524
525 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
526 match &mut self.inner {
527 Inner::Stream(inner) => inner.receiver.poll_next_unpin(cx),
528 Inner::Standard(inner) => inner.receiver.poll_next_unpin(cx),
529 }
530 }
531}
532
533impl AsyncRead for WebSocket {
534 fn poll_read(
535 mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut tokio::io::ReadBuf,
536 ) -> Poll<io::Result<()>> {
537 while self.read_buf.is_empty() {
538 let Some(msg) = ready!(self.as_mut().poll_next(cx)?) else { return Poll::Ready(Ok(())) };
539 self.read_buf = msg.to_vec();
540 }
541
542 let part = if buf.remaining() < self.read_buf.len() {
543 let rem = self.read_buf.split_off(buf.remaining());
544 mem::replace(&mut self.read_buf, rem)
545 } else {
546 mem::take(&mut self.read_buf)
547 };
548
549 buf.put_slice(&part);
550 Poll::Ready(Ok(()))
551 }
552}
553
/// Sending half of a [`WebSocket`], obtained from [`WebSocket::into_split`].
///
/// Send via the [`Sink`] or [`AsyncWrite`] implementations.
pub struct WebSocketSender {
    // Backend-specific sender.
    inner: SenderInner,
    // Metadata shared with the receiving half.
    info: Rc<Info>,
}

// Backend sender, one variant per `Interface`.
enum SenderInner {
    Stream(stream::Sender),
    Standard(standard::Sender),
}
567
568impl fmt::Debug for WebSocketSender {
569 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
570 f.debug_struct("WebSocketSender")
571 .field("url", &self.info.url)
572 .field("protocol", &self.protocol())
573 .field("interface", &self.interface())
574 .finish()
575 }
576}
577
578impl WebSocketSender {
579 pub fn url(&self) -> &str {
581 &self.info.url
582 }
583
584 pub fn protocol(&self) -> &str {
586 &self.info.protocol
587 }
588
589 pub fn interface(&self) -> Interface {
591 self.info.interface
592 }
593
594 pub fn close(self) {
598 self.close_with_reason(CloseCode::NormalClosure, "");
599 }
600
601 #[track_caller]
609 pub fn close_with_reason(self, code: CloseCode, reason: &str) {
610 if !code.is_valid() {
611 panic!("WebSocket close code {code} is invalid");
612 }
613
614 match self.inner {
615 SenderInner::Stream(sender) => sender.close(code.into(), reason),
616 SenderInner::Standard(sender) => sender.close(code.into(), reason),
617 }
618 }
619
620 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
621 match &mut self.inner {
622 SenderInner::Stream(inner) => inner.poll_ready_unpin(cx),
623 SenderInner::Standard(inner) => inner.poll_ready_unpin(cx),
624 }
625 }
626
627 fn start_send(mut self: Pin<&mut Self>, item: &JsValue) -> Result<(), io::Error> {
628 match &mut self.inner {
629 SenderInner::Stream(inner) => inner.start_send_unpin(item),
630 SenderInner::Standard(inner) => inner.start_send_unpin(item),
631 }
632 }
633
634 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
635 match &mut self.inner {
636 SenderInner::Stream(inner) => inner.poll_flush_unpin(cx),
637 SenderInner::Standard(inner) => inner.poll_flush_unpin(cx),
638 }
639 }
640
641 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
642 match &mut self.inner {
643 SenderInner::Stream(inner) => inner.poll_close_unpin(cx),
644 SenderInner::Standard(inner) => inner.poll_close_unpin(cx),
645 }
646 }
647}
648
649impl Sink<&str> for WebSocketSender {
650 type Error = io::Error;
651
652 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
653 self.poll_ready(cx)
654 }
655
656 fn start_send(self: Pin<&mut Self>, item: &str) -> Result<(), Self::Error> {
657 self.start_send(&JsValue::from_str(item))
658 }
659
660 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
661 self.poll_flush(cx)
662 }
663
664 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
665 self.poll_close(cx)
666 }
667}
668
669impl Sink<String> for WebSocketSender {
670 type Error = io::Error;
671
672 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
673 self.poll_ready(cx)
674 }
675
676 fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
677 self.start_send(&JsValue::from_str(&item))
678 }
679
680 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
681 self.poll_flush(cx)
682 }
683
684 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
685 self.poll_close(cx)
686 }
687}
688
689impl Sink<&[u8]> for WebSocketSender {
690 type Error = io::Error;
691
692 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
693 self.poll_ready(cx)
694 }
695
696 fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
697 self.start_send(&Uint8Array::from(item))
698 }
699
700 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
701 self.poll_flush(cx)
702 }
703
704 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
705 self.poll_close(cx)
706 }
707}
708
709impl Sink<Vec<u8>> for WebSocketSender {
710 type Error = io::Error;
711
712 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
713 self.poll_ready(cx)
714 }
715
716 fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
717 self.start_send(&Uint8Array::from(&item[..]))
718 }
719
720 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
721 self.poll_flush(cx)
722 }
723
724 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
725 self.poll_close(cx)
726 }
727}
728
729impl Sink<Msg> for WebSocketSender {
730 type Error = io::Error;
731
732 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
733 self.poll_ready(cx)
734 }
735
736 fn start_send(self: Pin<&mut Self>, item: Msg) -> Result<(), Self::Error> {
737 match item {
738 Msg::Text(text) => self.start_send(&JsValue::from_str(&text)),
739 Msg::Binary(vec) => self.start_send(&Uint8Array::from(&vec[..])),
740 }
741 }
742
743 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
744 self.poll_flush(cx)
745 }
746
747 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
748 self.poll_close(cx)
749 }
750}
751
752impl AsyncWrite for WebSocketSender {
753 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
754 ready!(self.as_mut().poll_ready(cx))?;
755 self.start_send(&Uint8Array::from(buf))?;
756 Poll::Ready(Ok(buf.len()))
757 }
758
759 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
760 self.poll_flush(cx)
761 }
762
763 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
764 self.poll_close(cx)
765 }
766}
767
/// Receiving half of a [`WebSocket`], obtained from [`WebSocket::into_split`].
///
/// Receive via the [`Stream`] or [`AsyncRead`] implementations.
pub struct WebSocketReceiver {
    // Backend-specific receiver.
    inner: ReceiverInner,
    // Metadata shared with the sending half.
    info: Rc<Info>,
    // Bytes of the current message not yet handed out by `AsyncRead`.
    read_buf: Vec<u8>,
}

// Backend receiver, one variant per `Interface`.
enum ReceiverInner {
    Stream(stream::Receiver),
    Standard(standard::Receiver),
}
782
783impl fmt::Debug for WebSocketReceiver {
784 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
785 f.debug_struct("WebSocketReceiver")
786 .field("url", &self.info.url)
787 .field("protocol", &self.protocol())
788 .field("interface", &self.interface())
789 .finish()
790 }
791}
792
793impl WebSocketReceiver {
794 pub fn url(&self) -> &str {
796 &self.info.url
797 }
798
799 pub fn protocol(&self) -> &str {
801 &self.info.protocol
802 }
803
804 pub fn interface(&self) -> Interface {
806 self.info.interface
807 }
808}
809
810impl Stream for WebSocketReceiver {
811 type Item = io::Result<Msg>;
812
813 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
814 match &mut self.inner {
815 ReceiverInner::Stream(inner) => inner.poll_next_unpin(cx),
816 ReceiverInner::Standard(inner) => inner.poll_next_unpin(cx),
817 }
818 }
819}
820
821impl AsyncRead for WebSocketReceiver {
822 fn poll_read(
823 mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut tokio::io::ReadBuf,
824 ) -> Poll<io::Result<()>> {
825 while self.read_buf.is_empty() {
826 let Some(msg) = ready!(self.as_mut().poll_next(cx)?) else { return Poll::Ready(Ok(())) };
827 self.read_buf = msg.to_vec();
828 }
829
830 let part = if buf.remaining() < self.read_buf.len() {
831 let rem = self.read_buf.split_off(buf.remaining());
832 mem::replace(&mut self.read_buf, rem)
833 } else {
834 mem::take(&mut self.read_buf)
835 };
836
837 buf.put_slice(&part);
838 Poll::Ready(Ok(()))
839 }
840}