tokio_websockets/proto/
stream.rs

1//! Frame aggregating abstraction over the low-level [`super::codec`]
2//! implementation that provides [`futures_sink::Sink`] and
3//! [`futures_core::Stream`] implementations that take [`Message`] as a
4//! parameter.
5use std::{
6    collections::VecDeque,
7    io::{self, IoSlice},
8    mem::{replace, take},
9    pin::Pin,
10    task::{ready, Context, Poll, Waker},
11};
12
13use bytes::{Buf, BytesMut};
14use futures_core::Stream;
15use futures_sink::Sink;
16use tokio::io::{AsyncRead, AsyncWrite};
17use tokio_util::{codec::FramedRead, io::poll_write_buf};
18
19#[cfg(any(feature = "client", feature = "server"))]
20use super::types::Role;
21use super::{
22    codec::WebSocketProtocol,
23    types::{Frame, Message, OpCode, Payload, StreamState},
24    Config, Limits,
25};
26use crate::{CloseCode, Error};
27
28/// Helper struct for storing a frame header, the header size and payload.
29#[derive(Debug)]
30struct EncodedFrame {
31    /// Encoded frame header and mask.
32    header: [u8; 14],
33    /// Potentially masked message payload, ready for writing to the I/O.
34    payload: Payload,
35}
36
37impl EncodedFrame {
38    /// Whether or not this frame is masked.
39    #[inline]
40    fn is_masked(&self) -> bool {
41        self.header[1] >> 7 != 0
42    }
43
44    /// Returns the length of the combined header and mask in bytes.
45    #[inline]
46    fn header_len(&self) -> usize {
47        let mask_bytes = if self.is_masked() { 4 } else { 0 };
48        match self.header[1] & 127 {
49            127 => 10 + mask_bytes,
50            126 => 4 + mask_bytes,
51            _ => 2 + mask_bytes,
52        }
53    }
54
55    /// Total length of the frame.
56    fn len(&self) -> usize {
57        self.header_len() + self.payload.len()
58    }
59}
60
61/// Queued up frames that are being sent.
62#[derive(Debug)]
63struct FrameQueue {
64    /// Queue of outgoing frames to send. Some parts of the first item may have
65    /// been sent already.
66    queue: VecDeque<EncodedFrame>,
67    /// Amount of partial bytes written of the first frame in the queue.
68    bytes_written: usize,
69    /// Total amount of bytes remaining to be sent in the frame queue.
70    pending_bytes: usize,
71}
72
73impl FrameQueue {
74    /// Creates a new, empty [`FrameQueue`].
75    #[cfg(any(feature = "client", feature = "server"))]
76    fn new() -> Self {
77        Self {
78            queue: VecDeque::with_capacity(1),
79            bytes_written: 0,
80            pending_bytes: 0,
81        }
82    }
83
84    /// Queue a frame to be sent.
85    fn push(&mut self, item: EncodedFrame) {
86        self.pending_bytes += item.len();
87        self.queue.push_back(item);
88    }
89}
90
91impl Buf for FrameQueue {
92    fn remaining(&self) -> usize {
93        self.pending_bytes
94    }
95
96    fn chunk(&self) -> &[u8] {
97        if let Some(frame) = self.queue.front() {
98            if self.bytes_written >= frame.header_len() {
99                unsafe {
100                    frame
101                        .payload
102                        .get_unchecked(self.bytes_written - frame.header_len()..)
103                }
104            } else {
105                &frame.header[self.bytes_written..frame.header_len()]
106            }
107        } else {
108            &[]
109        }
110    }
111
112    fn advance(&mut self, mut cnt: usize) {
113        self.pending_bytes -= cnt;
114        cnt += self.bytes_written;
115
116        while cnt > 0 {
117            let item = self
118                .queue
119                .front()
120                .expect("advance called with too long count");
121            let item_len = item.len();
122
123            if cnt >= item_len {
124                self.queue.pop_front();
125                self.bytes_written = 0;
126                cnt -= item_len;
127            } else {
128                self.bytes_written = cnt;
129                return;
130            }
131        }
132    }
133
134    fn chunks_vectored<'a>(&'a self, dst: &mut [io::IoSlice<'a>]) -> usize {
135        let mut n = 0;
136        for (idx, frame) in self.queue.iter().enumerate() {
137            if n >= dst.len() {
138                break;
139            }
140
141            if idx == 0 {
142                if frame.header_len() > self.bytes_written {
143                    dst[n] = IoSlice::new(&frame.header[self.bytes_written..frame.header_len()]);
144                    n += 1;
145                }
146
147                if !frame.payload.is_empty() && n < dst.len() {
148                    dst[n] = IoSlice::new(unsafe {
149                        frame
150                            .payload
151                            .get_unchecked(self.bytes_written.saturating_sub(frame.header_len())..)
152                    });
153                    n += 1;
154                }
155            } else {
156                dst[n] = IoSlice::new(&frame.header[..frame.header_len()]);
157                n += 1;
158                if !frame.payload.is_empty() && n < dst.len() {
159                    dst[n] = IoSlice::new(&frame.payload);
160                    n += 1;
161                }
162            }
163        }
164
165        n
166    }
167}
168
/// A WebSocket stream that full messages can be read from and written to.
///
/// The stream implements [`futures_sink::Sink`] and [`futures_core::Stream`].
///
/// You must use a [`ClientBuilder`] or [`ServerBuilder`] to
/// obtain a WebSocket stream.
///
/// For usage examples, see the top level crate documentation, which showcases a
/// simple echo server and client.
///
/// [`ClientBuilder`]: crate::ClientBuilder
/// [`ServerBuilder`]: crate::ServerBuilder
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct WebSocketStream<T> {
    /// The underlying stream using the [`WebSocketProtocol`] to read full
    /// frames. Outgoing frames do not pass through the codec: they are written
    /// straight to the wrapped I/O from `frame_queue`.
    inner: FramedRead<T, WebSocketProtocol>,

    /// Configuration for the stream. Supplies the frame size used to split
    /// outgoing messages and the threshold of pending bytes at which
    /// `poll_ready` starts flushing.
    config: Config,

    /// The [`StreamState`] of the current stream, tracking how far the closing
    /// handshake has progressed.
    state: StreamState,

    /// Payload of the full message that is being assembled from fragments.
    partial_payload: BytesMut,
    /// Opcode of the full message that is being assembled. This is
    /// [`OpCode::Continuation`] while no message is being assembled.
    partial_opcode: OpCode,

    /// Scratch buffer that outgoing frame headers are formatted into before
    /// being copied into the queued frame.
    header_buf: [u8; 14],

    /// Queue of outgoing frames to send.
    frame_queue: FrameQueue,

    /// Waker used for currently actively polling
    /// [`WebSocketStream::poll_flush`] until completion. `None` while no flush
    /// is pending.
    flushing_waker: Option<Waker>,
}
209
impl<T> WebSocketStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Create a new [`WebSocketStream`] from a raw stream.
    ///
    /// The stream starts out in [`StreamState::Active`] with an empty frame
    /// queue and no partially assembled message.
    #[cfg(any(feature = "client", feature = "server"))]
    pub(crate) fn from_raw_stream(stream: T, role: Role, config: Config, limits: Limits) -> Self {
        Self {
            inner: FramedRead::new(stream, WebSocketProtocol::new(role, limits)),
            config,
            state: StreamState::Active,
            partial_payload: BytesMut::new(),
            partial_opcode: OpCode::Continuation,
            header_buf: [0; 14],
            frame_queue: FrameQueue::new(),
            flushing_waker: None,
        }
    }

    /// Create a new [`WebSocketStream`] from an existing [`FramedRead`]. This
    /// allows for reusing the internal buffer of the [`FramedRead`] object.
    ///
    /// The previous decoder is discarded and replaced with a fresh
    /// [`WebSocketProtocol`] for the given role and limits.
    #[cfg(any(feature = "client", feature = "server"))]
    pub(crate) fn from_framed<U>(
        framed: FramedRead<T, U>,
        role: Role,
        config: Config,
        limits: Limits,
    ) -> Self {
        Self {
            inner: framed.map_decoder(|_| WebSocketProtocol::new(role, limits)),
            config,
            state: StreamState::Active,
            partial_payload: BytesMut::new(),
            partial_opcode: OpCode::Continuation,
            header_buf: [0; 14],
            frame_queue: FrameQueue::new(),
            flushing_waker: None,
        }
    }

    /// Returns a reference to the underlying I/O stream wrapped by this stream.
    ///
    /// Care should be taken not to tamper with the stream of data to avoid
    /// corrupting the stream of frames.
    pub fn get_ref(&self) -> &T {
        self.inner.get_ref()
    }

    /// Returns a mutable reference to the underlying I/O stream wrapped by this
    /// stream.
    ///
    /// Care should be taken not to tamper with the stream of data to avoid
    /// corrupting the stream of frames.
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }

    /// Returns a reference to the inner websocket limits.
    pub fn limits(&self) -> &Limits {
        &self.inner.decoder().limits
    }

    /// Returns a mutable reference to the inner websocket limits.
    pub fn limits_mut(&mut self) -> &mut Limits {
        &mut self.inner.decoder_mut().limits
    }

    /// Attempt to pull out the next frame from the [`FramedRead`] of this
    /// stream and from that update the stream's internal state.
    ///
    /// Besides reading, this opportunistically flushes pending outgoing frames
    /// and queues the replies that incoming frames call for: a close
    /// acknowledgement for a close frame and a pong for a ping.
    ///
    /// Returns `Poll::Ready(None)` once the closing handshake has completed or
    /// the underlying stream has ended.
    ///
    /// # Errors
    ///
    /// This method returns an [`Error`] if reading from the stream fails or a
    /// protocol violation is encountered.
    fn poll_next_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame, Error>>> {
        // In the case of Active or ClosedByUs, we want to receive more messages from
        // the remote. In the case of ClosedByPeer, we have to flush to make sure our
        // close acknowledge goes through.
        if self.state == StreamState::CloseAcknowledged {
            return Poll::Ready(None);
        } else if self.state == StreamState::ClosedByPeer {
            ready!(self.as_mut().poll_flush(cx))?;
            self.state = StreamState::CloseAcknowledged;
            return Poll::Ready(None);
        }

        // If there are pending items, try to flush the sink.
        // Futures only store a single waker. If we use poll_flush(cx) here, the stored
        // waker (i.e. usually that of the write task) is replaced with our waker (i.e.
        // that of the read task) and our write task may never get woken up again. We
        // circumvent this by not calling poll_flush at all if poll_flush is polled by
        // another task at the moment.
        if self.frame_queue.has_remaining() {
            let waker = self.flushing_waker.clone();
            // A pending flush is deliberately ignored here, only errors are
            // propagated: reading must not be blocked on writing.
            _ = self.as_mut().poll_flush(&mut Context::from_waker(
                waker.as_ref().unwrap_or(cx.waker()),
            ))?;
        }

        let frame = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
            Some(Ok(frame)) => frame,
            Some(Err(e)) => {
                // After an I/O error, or if we already sent a close frame, no
                // close frame is queued and the stream is considered done.
                if matches!(e, Error::Io(_)) || self.state == StreamState::ClosedByUs {
                    self.state = StreamState::CloseAcknowledged;
                } else {
                    // The state is updated before queueing so that
                    // `queue_frame` does not mark the close as initiated by us.
                    self.state = StreamState::ClosedByPeer;

                    // Tell the peer why the connection is being closed. The
                    // frame goes out on the flush performed by the next poll.
                    match &e {
                        Error::Protocol(e) => self.queue_frame(Frame::from(e)),
                        Error::PayloadTooLong { max_len, .. } => self.queue_frame(
                            Message::close(
                                Some(CloseCode::MESSAGE_TOO_BIG),
                                &format!("max length: {max_len}"),
                            )
                            .into(),
                        ),
                        _ => {}
                    }
                }
                return Poll::Ready(Some(Err(e)));
            }
            None => return Poll::Ready(None),
        };

        match frame.opcode {
            OpCode::Close => match self.state {
                StreamState::Active => {
                    self.state = StreamState::ClosedByPeer;

                    // Acknowledge the close with only the first two payload
                    // bytes (the close code), dropping the rest of the payload.
                    let mut frame = frame.clone();
                    frame.payload.truncate(2);

                    self.queue_frame(frame);
                }
                // Both states return early at the top of this function, so no
                // frame can be read while in them.
                StreamState::ClosedByPeer | StreamState::CloseAcknowledged => {
                    debug_assert!(false, "unexpected StreamState");
                }
                // The peer acknowledged the close frame that we sent earlier.
                StreamState::ClosedByUs => {
                    self.state = StreamState::CloseAcknowledged;
                }
            },
            // Pings are answered with a pong carrying the same payload, but
            // only while the stream is still active.
            OpCode::Ping if self.state == StreamState::Active => {
                let mut frame = frame.clone();
                frame.opcode = OpCode::Pong;

                self.queue_frame(frame);
            }
            _ => {}
        }

        Poll::Ready(Some(Ok(frame)))
    }

    /// Masks and queues a frame for sending when [`poll_flush`] gets called.
    ///
    /// Queueing a close frame moves the stream to [`StreamState::ClosedByUs`],
    /// unless the state is [`StreamState::ClosedByPeer`], in which case the
    /// frame is the acknowledgement of the peer's close and the state is left
    /// untouched.
    ///
    /// The payload is only masked when the `client` feature is enabled and the
    /// stream has the client role.
    fn queue_frame(
        &mut self,
        #[cfg_attr(not(feature = "client"), allow(unused_mut))] mut frame: Frame,
    ) {
        if frame.opcode == OpCode::Close && self.state != StreamState::ClosedByPeer {
            self.state = StreamState::ClosedByUs;
        }

        // Formats the frame header into the scratch buffer. The returned
        // reference is used below as the location the mask is written to.
        #[cfg_attr(not(feature = "client"), allow(unused_variables))]
        let mask = frame.encode(&mut self.header_buf);

        #[cfg(feature = "client")]
        {
            if self.inner.decoder().role == Role::Client {
                let mut payload = BytesMut::from(frame.payload);
                crate::rand::get_mask(mask);
                // mask::frame will mutate the mask in-place, but we want to send the original
                // mask. This is essentially a u32, so copying it is cheap and easier than
                // special-casing this in the masking implementation.
                // &mut *mask won't work, the compiler will optimize the deref/copy away
                let mut mask_copy = *mask;
                crate::mask::frame(&mut mask_copy, &mut payload);
                frame.payload = Payload::from(payload);
                // Set the mask bit in the second header byte.
                self.header_buf[1] |= 1 << 7;
            }
        }

        // The header buffer is copied into the queued frame, so it is free to
        // be reused for the next frame.
        let item = EncodedFrame {
            header: self.header_buf,
            payload: frame.payload,
        };
        self.frame_queue.push(item);
    }

    /// Sets the waker that is currently flushing to a new one and does nothing
    /// if the waker is the same.
    fn set_flushing_waker(&mut self, waker: &Waker) {
        // Skip the clone if the stored waker would wake the same task.
        if !self
            .flushing_waker
            .as_ref()
            .is_some_and(|w| w.will_wake(waker))
        {
            self.flushing_waker = Some(waker.clone());
        }
    }
}
413
414impl<T> Stream for WebSocketStream<T>
415where
416    T: AsyncRead + AsyncWrite + Unpin,
417{
418    type Item = Result<Message, Error>;
419
420    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
421        let max_len = self.inner.decoder().limits.max_payload_len;
422
423        loop {
424            let (opcode, payload, fin) = match ready!(self.as_mut().poll_next_frame(cx)?) {
425                Some(frame) => (frame.opcode, frame.payload, frame.is_final),
426                None => return Poll::Ready(None),
427            };
428            let len = self.partial_payload.len() + payload.len();
429
430            if opcode != OpCode::Continuation {
431                if fin {
432                    return Poll::Ready(Some(Ok(Message { opcode, payload })));
433                }
434                self.partial_opcode = opcode;
435                self.partial_payload = BytesMut::from(payload);
436            } else if len > max_len {
437                return Poll::Ready(Some(Err(Error::PayloadTooLong { len, max_len })));
438            } else {
439                self.partial_payload.extend_from_slice(&payload);
440            }
441
442            if fin {
443                break;
444            }
445        }
446
447        let opcode = replace(&mut self.partial_opcode, OpCode::Continuation);
448        let mut payload = Payload::from(take(&mut self.partial_payload));
449        payload.set_utf8_validated(opcode == OpCode::Text);
450
451        Poll::Ready(Some(Ok(Message { opcode, payload })))
452    }
453}
454
// The tokio-util implementation of a sink uses a buffer which start_send
// appends to and poll_flush tries to write from. This makes sense, but comes
// with a hefty performance penalty when sending large payloads, since this adds
// a memmove from the payload to the buffer. We completely avoid that overhead
// by storing messages in a deque.
impl<T> Sink<Message> for WebSocketStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = Error;

    /// Reports readiness to accept another message. If the amount of queued
    /// bytes has reached the configured flush threshold, this flushes first
    /// and is only ready once the flush has completed.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // tokio-util calls poll_flush once its write buffer has reached the
        // backpressure boundary (8 KiB by default), otherwise it returns Ready.
        // We replicate that behavior with the configurable flush threshold.
        if self.frame_queue.remaining() >= self.config.flush_threshold {
            self.as_mut().poll_flush(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Queues a message for sending without writing anything to the I/O yet.
    ///
    /// # Errors
    ///
    /// Returns [`Error::AlreadyClosed`] if the stream is not in the active
    /// state anymore.
    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        if self.state != StreamState::Active {
            return Err(Error::AlreadyClosed);
        }

        // Control messages and messages that fit into the configured frame
        // size are sent as a single frame.
        if item.opcode.is_control() || item.payload.len() <= self.config.frame_size {
            let frame: Frame = item.into();
            self.queue_frame(frame);
        } else {
            // Chunk the message into frames
            for frame in item.into_frames(self.config.frame_size) {
                self.queue_frame(frame);
            }
        }

        Ok(())
    }

    /// Writes all queued frames to the underlying I/O and flushes it.
    ///
    /// While the flush is pending, the waker of the polling task is remembered
    /// in `flushing_waker`. It is cleared again once the flush has completed
    /// or failed.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] if writing or flushing fails, or if the I/O
    /// accepts zero bytes while frames are still queued. In either case the
    /// stream state is set to [`StreamState::CloseAcknowledged`].
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Borrow checker hacks... It needs this to understand that we can separately
        // borrow the fields of the struct mutably
        let this = self.get_mut();
        let frame_queue = &mut this.frame_queue;
        let io = this.inner.get_mut();
        let flushing_waker = &mut this.flushing_waker;

        while frame_queue.has_remaining() {
            // The queue is passed as the buffer to write from, so it is
            // advanced by however many bytes the I/O accepted.
            let n = match poll_write_buf(Pin::new(io), cx, frame_queue) {
                Poll::Ready(Ok(n)) => n,
                Poll::Ready(Err(e)) => {
                    *flushing_waker = None;
                    this.state = StreamState::CloseAcknowledged;
                    return Poll::Ready(Err(Error::Io(e)));
                }
                Poll::Pending => {
                    this.set_flushing_waker(cx.waker());
                    return Poll::Pending;
                }
            };

            // The I/O accepted nothing although there is data left to write.
            if n == 0 {
                *flushing_waker = None;
                this.state = StreamState::CloseAcknowledged;
                return Poll::Ready(Err(Error::Io(io::ErrorKind::WriteZero.into())));
            }
        }

        // Everything has been handed to the I/O, now flush the I/O itself.
        match Pin::new(io).poll_flush(cx) {
            Poll::Ready(Ok(())) => {
                *flushing_waker = None;
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => {
                *flushing_waker = None;
                this.state = StreamState::CloseAcknowledged;
                Poll::Ready(Err(Error::Io(e)))
            }
            Poll::Pending => {
                this.set_flushing_waker(cx.waker());
                Poll::Pending
            }
        }
    }

    /// Closes the stream: queues a close frame if the stream is still active,
    /// reads from the stream until it ends, flushes everything that is queued
    /// and finally shuts down the underlying I/O.
    ///
    /// # Errors
    ///
    /// Returns an error if flushing or shutting down the I/O fails.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.state == StreamState::Active {
            self.queue_frame(Frame::DEFAULT_CLOSE);
        }
        // Drain the stream until it ends. Messages and errors yielded on the
        // way are discarded.
        while ready!(self.as_mut().poll_next(cx)).is_some() {}

        ready!(self.as_mut().poll_flush(cx))?;
        Pin::new(self.inner.get_mut())
            .poll_shutdown(cx)
            .map_err(Error::Io)
    }
}
551}