1use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::io;
8use std::marker::PhantomData;
9use std::ops::Deref as _;
10use std::pin::Pin;
11use std::str::from_utf8 as str_from_utf8;
12use std::task::Poll;
13use std::time::Duration;
14
15use futures::task::Context;
16use futures::Sink;
17use futures::SinkExt as _;
18use futures::Stream;
19use futures::StreamExt as _;
20
21use tokio::time::interval;
22use tokio::time::Interval;
23use tokio::time::MissedTickBehavior;
24use tokio_tungstenite::tungstenite::Bytes;
25use tokio_tungstenite::tungstenite::Error as WebSocketError;
26use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
27use tokio_tungstenite::tungstenite::Utf8Bytes;
28
29use tracing::debug;
30use tracing::error;
31use tracing::field::debug;
32use tracing::field::DebugValue;
33use tracing::trace;
34
35
/// Tracks whether the connection needs to be probed with a ping.
///
/// The state advances on every tick of the ping interval and is reset
/// to `NotNeeded` whenever a message is received from the remote side.
#[derive(Clone, Copy)]
#[derive(Debug)]
enum Ping {
  /// Traffic was observed since the last tick (this is also the initial
  /// state), so no ping has to be sent on the next tick.
  NotNeeded,
  /// No traffic was observed during the last interval; a ping gets sent
  /// on the next tick.
  Needed,
  /// A ping has been sent and nothing was received since. Reaching the
  /// next tick in this state is reported as a timeout.
  Pending,
}
50
51
/// A data message exchanged over the wrapped WebSocket channel.
///
/// Control frames (ping, pong, close) are handled by the wrapper itself
/// and never surface as a `Message`.
#[derive(PartialEq, Debug)]
pub enum Message {
  /// A UTF-8 text message.
  Text(String),
  /// A binary message with arbitrary content.
  Binary(Vec<u8>),
}
60
61impl From<Message> for WebSocketMessage {
62 fn from(message: Message) -> Self {
63 match message {
64 Message::Text(data) => WebSocketMessage::Text(Utf8Bytes::from(data)),
65 Message::Binary(data) => WebSocketMessage::Binary(Bytes::from(data)),
66 }
67 }
68}
69
70
/// The state of a single message being pushed into a sink.
#[derive(Debug)]
enum SendMessageState<M> {
  /// No message is in flight.
  Unused,
  /// A message is waiting for the sink to become ready. The `Option`
  /// exists only so the message can be moved out while the state is
  /// being advanced; it is `Some` whenever this state is set.
  Pending(Option<M>),
  /// The message has been handed to the sink, which still has to be
  /// flushed.
  Flush,
}
85
86impl<M> SendMessageState<M> {
87 fn advance<S>(&mut self, sink: &mut S, ctx: &mut Context<'_>) -> Result<(), S::Error>
89 where
90 S: Sink<M> + Unpin,
91 M: Debug,
92 {
93 loop {
94 match self {
95 Self::Unused => break Ok(()),
96 Self::Pending(message) => {
97 match sink.poll_ready_unpin(ctx) {
98 Poll::Pending => return Ok(()),
99 Poll::Ready(Ok(())) => (),
100 Poll::Ready(Err(err)) => {
101 *self = Self::Unused;
102 return Err(err)
103 },
104 }
105
106 let message = message.take();
107 *self = Self::Unused;
108 debug!(
109 channel = debug(sink as *const _),
110 send_msg = debug(&message)
111 );
112
113 if let Some(message) = message {
114 sink.start_send_unpin(message)?;
115 *self = Self::Flush;
116 }
117 },
118 Self::Flush => {
119 trace!(channel = debug(sink as *const _), msg = "flushing");
120 match sink.poll_flush_unpin(ctx) {
121 Poll::Pending => break Ok(()),
122 Poll::Ready(Ok(())) => {
123 *self = Self::Unused;
124 },
125 Poll::Ready(Err(err)) => {
126 *self = Self::Unused;
127 break Err(err)
128 },
129 }
130 },
131 }
132 }
133 }
134
135 fn set(&mut self, message: M) {
137 *self = Self::Pending(Some(message))
138 }
139}
140
141fn set_message<S, M>(channel: &S, message_state: &mut SendMessageState<M>, message: M)
147where
148 M: Debug,
149{
150 match message_state {
151 SendMessageState::Unused => (),
152 SendMessageState::Pending(old_message) => {
153 debug!(
154 channel = debug(channel as *const _),
155 send_msg_old = debug(&old_message),
156 send_msg_new = debug(&message),
157 msg = "message overrun; last message has not been sent"
158 );
159 },
160 SendMessageState::Flush => {
161 debug!(
162 channel = debug(channel as *const _),
163 msg = "message overrun; last message has not been flushed"
164 );
165 },
166 }
167
168 message_state.set(message);
169}
170
171
172struct DebugMessage<'m> {
174 message: &'m WebSocketMessage,
175}
176
177impl Debug for DebugMessage<'_> {
178 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
179 match self.message {
180 WebSocketMessage::Binary(data) => {
183 if let Ok(s) = str_from_utf8(data) {
184 f.debug_tuple("Binary").field(&s).finish()
185 } else {
186 f.debug_tuple("Binary").field(&data.deref()).finish()
187 }
188 },
189 WebSocketMessage::Ping(data) => f.debug_tuple("Ping").field(&data.deref()).finish(),
190 WebSocketMessage::Pong(data) => f.debug_tuple("Pong").field(&data.deref()).finish(),
191 _ => Debug::fmt(self.message, f),
192 }
193 }
194}
195
196fn debug_message(message: &WebSocketMessage) -> DebugValue<DebugMessage<'_>> {
199 debug(DebugMessage { message })
200}
201
202
203#[derive(Debug)]
206struct Pinger {
207 ping: SendMessageState<WebSocketMessage>,
209 next_ping: Interval,
212 ping_state: Ping,
215}
216
217impl Pinger {
218 fn new(ping_interval: Duration) -> Self {
221 let mut next_ping = interval(ping_interval);
222 let () = next_ping.set_missed_tick_behavior(MissedTickBehavior::Delay);
226
227 Self {
228 ping: SendMessageState::Unused,
229 next_ping,
230 ping_state: Ping::NotNeeded,
231 }
232 }
233
234 #[allow(clippy::result_large_err)]
236 fn advance<S>(&mut self, sink: &mut S, ctx: &mut Context<'_>) -> Result<(), S::Error>
237 where
238 S: Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
239 {
240 let () = self.ping.advance(sink, ctx)?;
241 let mut result = Ok(());
242
243 loop {
246 match self.next_ping.poll_tick(ctx) {
247 Poll::Ready(_instant) => {
248 self.ping_state = match self.ping_state {
253 Ping::NotNeeded => {
254 trace!(
255 channel = debug(sink as *const _),
256 msg = "skipping ping due to activity"
257 );
258 Ping::Needed
262 },
263 Ping::Needed => {
264 let message = WebSocketMessage::Ping(Bytes::new());
269 let () = set_message(sink, &mut self.ping, message);
270
271 self.ping.advance(sink, ctx)?;
272 Ping::Pending
273 },
274 Ping::Pending => {
275 error!(
276 channel = debug(sink as *const _),
277 msg = "server failed to respond to pings"
278 );
279
280 let err = WebSocketError::Io(io::Error::new(
281 io::ErrorKind::TimedOut,
282 "server failed to respond to pings",
283 ));
284 result = Err(err);
285
286 Ping::Needed
291 },
292 };
293 },
294 Poll::Pending => break result,
295 }
296 }
297 }
298
299 fn activity(&mut self) {
303 self.ping_state = Ping::NotNeeded;
304 }
305}
306
307
/// A builder for configuring and creating [`Wrapper`] objects.
#[derive(Debug)]
pub struct Builder<S> {
  /// The interval at which the connection is checked for activity and
  /// pinged if it was idle; `None` disables pinging.
  ping_interval: Option<Duration>,
  /// Ties the builder to the channel type it wraps.
  _phantom: PhantomData<S>,
}
317
318impl<S> Builder<S> {
319 pub fn set_ping_interval(mut self, interval: Option<Duration>) -> Builder<S> {
321 self.ping_interval = interval;
322 self
323 }
324
325 pub fn build(self, channel: S) -> Wrapper<S> {
327 Wrapper {
328 inner: channel,
329 ping: self.ping_interval.map(Pinger::new),
330 }
331 }
332}
333
334impl<S> Default for Builder<S> {
335 fn default() -> Self {
336 Self {
337 ping_interval: Some(Duration::from_secs(30)),
338 _phantom: PhantomData,
339 }
340 }
341}
342
343
344#[derive(Debug)]
348#[must_use = "streams do nothing unless polled"]
349pub struct Wrapper<S> {
350 inner: S,
352 ping: Option<Pinger>,
354}
355
356impl<S> Wrapper<S> {
357 pub fn builder() -> Builder<S> {
359 Builder::default()
360 }
361}
362
363impl<S> Stream for Wrapper<S>
364where
365 S: Sink<WebSocketMessage, Error = WebSocketError>
366 + Stream<Item = Result<WebSocketMessage, WebSocketError>>
367 + Unpin,
368{
369 type Item = Result<Message, S::Error>;
370
371 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
372 let this = Pin::get_mut(self);
373
374 if let Some(ping) = &mut this.ping {
375 if let Err(err) = ping.advance(&mut this.inner, ctx) {
376 return Poll::Ready(Some(Err(err)))
377 }
378 }
379
380 loop {
381 match this.inner.poll_next_unpin(ctx) {
382 Poll::Pending => {
383 break Poll::Pending
386 },
387 Poll::Ready(None) => {
388 break Poll::Ready(None)
390 },
391 Poll::Ready(Some(Err(err))) => break Poll::Ready(Some(Err(err))),
392 Poll::Ready(Some(Ok(message))) => {
393 debug!(
394 channel = debug(&this.inner as *const _),
395 recv_msg = debug_message(&message)
396 );
397 let () = this.ping.as_mut().map(Pinger::activity).unwrap_or(());
398
399 match message {
400 WebSocketMessage::Text(data) => {
401 break Poll::Ready(Some(Ok(Message::Text(data.to_string()))))
402 },
403 WebSocketMessage::Binary(data) => {
404 break Poll::Ready(Some(Ok(Message::Binary(data.to_vec()))))
405 },
406 WebSocketMessage::Ping(_) => {
407 },
410 WebSocketMessage::Pong(_) => {
411 },
414 WebSocketMessage::Close(_) => {
415 break Poll::Ready(None)
420 },
421 WebSocketMessage::Frame(_) => {
422 },
425 }
426 },
427 }
428 }
429 }
430}
431
432impl<S> Sink<Message> for Wrapper<S>
433where
434 S: Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
435{
436 type Error = S::Error;
437
438 fn poll_ready(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
439 self.inner.poll_ready_unpin(ctx)
440 }
441
442 fn start_send(mut self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> {
443 let message = message.into();
444 debug!(
445 channel = debug(&self.inner as *const _),
446 send_msg = debug_message(&message)
447 );
448 self.inner.start_send_unpin(message)
449 }
450
451 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
452 trace!(channel = debug(&self.inner as *const _), msg = "flushing");
453 self.inner.poll_flush_unpin(ctx)
454 }
455
456 fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
457 self.inner.poll_close_unpin(ctx)
458 }
459}
460
461
#[cfg(test)]
mod tests {
  //! Tests for the WebSocket `Wrapper`, run against a mock server.

  use super::*;

  use std::future::Future;
  use std::sync::atomic::AtomicUsize;
  use std::sync::atomic::Ordering;
  use std::sync::Arc;

  use futures::future::ready;
  use futures::TryStreamExt as _;

  use rand::seq::IteratorRandom as _;
  use rand::thread_rng;
  use rand::Rng as _;

  use test_log::test;

  use tokio::time::pause;
  use tokio::time::sleep;
  use tokio::time::timeout;

  use tokio_tungstenite::connect_async;
  use tokio_tungstenite::tungstenite::error::ProtocolError;

  use url::Url;

  use crate::test::mock_server;
  use crate::test::WebSocketStream;


  /// Check that `debug_message` renders binary, ping, and pong payloads
  /// as expected.
  #[test]
  fn debug_websocket_message() {
    // Binary data that is valid UTF-8 is shown as a string.
    let message = WebSocketMessage::Binary(Bytes::from(b"this is a test".as_slice()));
    let expected = r#"Binary("this is a test")"#;
    assert_eq!(format!("{:?}", debug_message(&message)), expected);

    // A truncated multi-byte sequence is not valid UTF-8 and is shown as
    // raw bytes.
    let message = WebSocketMessage::Binary(Bytes::from([0xf0, 0x90, 0x80].as_slice()));
    let expected = r#"Binary([240, 144, 128])"#;
    assert_eq!(format!("{:?}", debug_message(&message)), expected);

    let message = WebSocketMessage::Ping(Bytes::new());
    let expected = r#"Ping([])"#;
    assert_eq!(format!("{:?}", debug_message(&message)), expected);
  }


  /// Start a mock server that hands its side of the connection to `f`,
  /// connect to it, and wrap the client side using `builder`.
  async fn serve_and_connect_with_builder<F, R>(
    builder: Builder<WebSocketStream>,
    f: F,
  ) -> Wrapper<WebSocketStream>
  where
    F: FnOnce(WebSocketStream) -> R + Send + Sync + 'static,
    R: Future<Output = Result<(), WebSocketError>> + Send + Sync + 'static,
  {
    let addr = mock_server(f).await;
    let url = Url::parse(&format!("ws://{}", addr)).unwrap();

    let (stream, _) = connect_async(url).await.unwrap();
    builder.build(stream)
  }

  /// Same as `serve_and_connect_with_builder`, but with a short 10ms ping
  /// interval so that ping-related behavior shows up quickly.
  async fn serve_and_connect<F, R>(f: F) -> Wrapper<WebSocketStream>
  where
    F: FnOnce(WebSocketStream) -> R + Send + Sync + 'static,
    R: Future<Output = Result<(), WebSocketError>> + Send + Sync + 'static,
  {
    let ping = Some(Duration::from_millis(10));
    let builder = Wrapper::builder().set_ping_interval(ping);
    serve_and_connect_with_builder(builder, f).await
  }

  /// A server that drops the connection without a closing handshake
  /// causes a `ResetWithoutClosingHandshake` protocol error.
  #[test(tokio::test)]
  async fn no_messages() {
    async fn test(_stream: WebSocketStream) -> Result<(), WebSocketError> {
      Ok(())
    }

    let err = serve_and_connect(test)
      .await
      .try_for_each(|_| ready(Ok(())))
      .await
      .unwrap_err();

    match err {
      WebSocketError::Protocol(ProtocolError::ResetWithoutClosingHandshake) => (),
      e => panic!("received unexpected error: {}", e),
    }
  }

  /// A close message from the server ends the stream without an error.
  #[test(tokio::test)]
  async fn direct_close() {
    async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
      stream.send(WebSocketMessage::Close(None)).await?;
      Ok(())
    }

    serve_and_connect(test)
      .await
      .try_for_each(|_| ready(Ok(())))
      .await
      .unwrap();
  }

  /// A ping from the server is answered with a pong while the client
  /// stream is polled. With pinging disabled, the client sends nothing
  /// on its own.
  #[test(tokio::test)]
  async fn ping_pong() {
    async fn test(stream: WebSocketStream) -> Result<(), WebSocketError> {
      let mut stream = stream.fuse();

      stream.send(WebSocketMessage::Ping(Bytes::new())).await?;
      assert_eq!(
        stream.next().await.unwrap()?,
        WebSocketMessage::Pong(Bytes::new()),
      );

      // No further message (in particular no client ping) is expected.
      let future = stream.select_next_some();
      assert!(timeout(Duration::from_millis(20), future).await.is_err());

      stream.send(WebSocketMessage::Close(None)).await?;
      Ok(())
    }

    let builder = Wrapper::builder().set_ping_interval(None);
    serve_and_connect_with_builder(builder, test)
      .await
      .try_for_each(|_| ready(Ok(())))
      .await
      .unwrap();
  }

  /// An idle connection is probed with pings.
  #[test(tokio::test)]
  async fn pings_are_sent() {
    async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
      // Reading on the server side also answers the pings, which keeps
      // the client from timing out.
      for _ in 0..2 {
        assert!(matches!(
          stream.next().await.unwrap()?,
          WebSocketMessage::Ping(_)
        ));
      }

      stream.send(WebSocketMessage::Close(None)).await?;
      Ok(())
    }

    serve_and_connect(test)
      .await
      .try_for_each(|_| ready(Ok(())))
      .await
      .unwrap();
  }

  /// Ticks missed while the stream is not polled do not result in a
  /// burst of pings once polling resumes.
  #[test(tokio::test)]
  async fn no_ping_bursts() {
    let counter = Arc::new(AtomicUsize::new(0));
    let clone = counter.clone();

    // The server counts the pings it receives and rejects anything else.
    let test = |stream: WebSocketStream| async move {
      let mut stream = stream.fuse();

      loop {
        let msg = stream.next().await.unwrap().unwrap();
        if let WebSocketMessage::Ping(_) = msg {
          let _ = clone.fetch_add(1, Ordering::Relaxed);

        } else {
          panic!("received unexpected message: {msg:?}")
        }
      }
    };

    // Pause time so that the long sleep below completes without actually
    // waiting.
    let () = pause();

    let wrapper = serve_and_connect(test).await;
    // Let many ping intervals elapse without polling the wrapper.
    let () = sleep(Duration::from_secs(10)).await;

    let future = wrapper.for_each(|result| {
      assert!(result.is_ok(), "{result:?}");
      ready(())
    });

    assert!(timeout(Duration::from_millis(15), future).await.is_err());

    // Within the 15ms window the pinger only gets to send a single ping:
    // the first tick merely marks a ping as needed, and the next one,
    // one 10ms interval later, sends it.
    assert_eq!(counter.load(Ordering::Relaxed), 1);
  }

  /// With pinging disabled, no pings are sent.
  #[test(tokio::test)]
  async fn no_pings_are_sent_when_disabled() {
    async fn test(stream: WebSocketStream) -> Result<(), WebSocketError> {
      let mut stream = stream.fuse();
      let future = stream.select_next_some();
      assert!(timeout(Duration::from_millis(20), future).await.is_err());

      stream.send(WebSocketMessage::Close(None)).await?;
      Ok(())
    }

    let builder = Wrapper::builder().set_ping_interval(None);
    serve_and_connect_with_builder(builder, test)
      .await
      .try_for_each(|_| ready(Ok(())))
      .await
      .unwrap();
  }

  /// A server that never answers pings causes repeated `TimedOut`
  /// errors.
  #[test(tokio::test)]
  async fn no_pong_response() {
    // The server never reads from its stream, so pings stay unanswered.
    async fn test(_stream: WebSocketStream) -> Result<(), WebSocketError> {
      sleep(Duration::from_secs(10)).await;
      Ok(())
    }

    let mut stream = serve_and_connect(test).await;

    // The error is reported again for each unanswered ping, and the
    // stream keeps working after each one.
    for _ in 0..5 {
      let err = stream.next().await.unwrap().unwrap_err();
      match err {
        WebSocketError::Io(err) => {
          assert_eq!(err.kind(), io::ErrorKind::TimedOut);
          assert_eq!(err.to_string(), "server failed to respond to pings");
        },
        _ => panic!("Received unexpected error: {err:?}"),
      }
    }
  }

  /// Data messages are passed through in order, while control messages
  /// (the pong here) are filtered out.
  #[test(tokio::test)]
  async fn send_messages() {
    async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
      stream
        .send(WebSocketMessage::Text(Utf8Bytes::from_static("42")))
        .await?;
      stream.send(WebSocketMessage::Pong(Bytes::new())).await?;
      stream
        .send(WebSocketMessage::Text(Utf8Bytes::from_static("43")))
        .await?;
      stream.send(WebSocketMessage::Close(None)).await?;
      Ok(())
    }

    let stream = serve_and_connect(test).await;
    let messages = stream.try_collect::<Vec<_>>().await.unwrap();
    assert_eq!(
      messages,
      vec![
        Message::Text("42".to_string()),
        Message::Text("43".to_string())
      ]
    );
  }

  /// Stream a large number of random pong, text, and binary messages
  /// through the wrapper.
  #[test(tokio::test)]
  #[ignore = "stress test; test takes a long time"]
  async fn stress_stream() {
    async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
      /// Create a random byte buffer of fewer than 32 bytes.
      fn random_buf() -> Bytes {
        let len = (0..32).choose(&mut thread_rng()).unwrap();
        let mut vec = Vec::new();
        vec.extend((0..len).map(|_| thread_rng().gen::<u8>()));
        Bytes::from(vec)
      }

      /// Create a random string of fewer than 32 characters.
      fn random_str() -> Utf8Bytes {
        let len = (0..32).choose(&mut thread_rng()).unwrap();
        let mut string = String::new();
        string.extend((0..len).map(|_| thread_rng().gen::<char>()));
        Utf8Bytes::from(string)
      }

      for _ in 0..50000 {
        let message = match (0..5).choose(&mut thread_rng()).unwrap() {
          0 => WebSocketMessage::Pong(random_buf()),
          i => {
            // Of the remaining values 1 through 4, the even ones produce
            // text and the odd ones binary messages.
            if i & 0x1 == 0 {
              WebSocketMessage::Text(random_str())
            } else {
              WebSocketMessage::Binary(random_buf())
            }
          },
        };

        stream.send(message).await?;
      }

      stream.send(WebSocketMessage::Close(None)).await?;
      Ok(())
    }

    serve_and_connect(test)
      .await
      .try_for_each(|_| ready(Ok(())))
      .await
      .unwrap();
  }
}