ttpkit_http/ws/
mod.rs

1//! WebSockets.
2
3mod frame;
4
5#[cfg(feature = "client")]
6mod client;
7
8#[cfg(feature = "server")]
9mod server;
10
11use std::{
12    borrow::Cow,
13    io,
14    mem::MaybeUninit,
15    pin::Pin,
16    task::{Context, Poll},
17};
18
19use base64::Engine;
20use bytes::{Buf, BufMut, Bytes, BytesMut};
21use futures::{Sink, SinkExt, Stream, StreamExt, ready};
22use sha1::{Digest, Sha1};
23use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
24
25use self::frame::{Frame, InvalidFrame};
26
27use crate::connection::Upgraded;
28
29#[cfg(feature = "server")]
30use crate::{Error, server::IncomingRequest};
31
32#[cfg(feature = "client")]
33#[cfg_attr(docsrs, doc(cfg(feature = "client")))]
34pub use self::client::{ClientHandshake, ClientHandshakeBuilder};
35
36#[cfg(feature = "server")]
37#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
38pub use self::server::{FutureServer, ServerHandshake};
39
40/// Create a new WS key.
41pub fn create_key() -> String {
42    base64::prelude::BASE64_STANDARD.encode(&rand::random::<[u8; 16]>()[..])
43}
44
45/// Create WS accept token for a given key.
46pub fn create_accept_token(key: &[u8]) -> String {
47    let suffix = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
48
49    let mut input = Vec::with_capacity(key.len() + suffix.len());
50
51    input.extend_from_slice(key);
52    input.extend_from_slice(suffix);
53
54    let hash = Sha1::digest(&input);
55
56    base64::prelude::BASE64_STANDARD.encode(hash.as_slice())
57}
58
/// Internal WS error.
///
/// The error is never exposed directly; it gets translated into an
/// `io::Error` (and, where applicable, into a close message) before it
/// reaches the user.
#[derive(Debug)]
enum InternalError {
    /// The peer violated the WS protocol (e.g. an invalid frame was received).
    ProtocolError,
    /// A text payload was not a valid UTF-8 string.
    InvalidString,
    /// The incoming data exceeded the configured input buffer capacity.
    MessageSizeExceeded,
    /// The connection ended while there was still incomplete data buffered.
    UnexpectedEof,
    /// An error reported by the underlying connection.
    IO(io::Error),
}
67
68impl InternalError {
69    /// Create a corresponding close message (if applicable).
70    fn to_close_message(&self) -> Option<CloseMessage> {
71        let status = match self {
72            Self::ProtocolError => CloseMessage::STATUS_PROTOCOL_ERROR,
73            Self::InvalidString => CloseMessage::STATUS_INVALID_DATA,
74            Self::MessageSizeExceeded => CloseMessage::STATUS_TOO_BIG,
75            _ => return None,
76        };
77
78        Some(CloseMessage::new_static(status, ""))
79    }
80}
81
82impl From<InvalidFrame> for InternalError {
83    fn from(_: InvalidFrame) -> Self {
84        Self::ProtocolError
85    }
86}
87
88impl From<io::Error> for InternalError {
89    fn from(err: io::Error) -> Self {
90        Self::IO(err)
91    }
92}
93
/// WS agent role.
///
/// The role tells which side of a WS connection the local endpoint is. It is
/// passed to the frame encoder and decoder.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum AgentRole {
    /// The local endpoint is the client side of the connection.
    Client,
    /// The local endpoint is the server side of the connection.
    Server,
}
100
101/// WS message.
102#[derive(Clone)]
103pub enum Message {
104    Text(String),
105    Data(Bytes),
106    Ping(Bytes),
107    Pong(Bytes),
108    Close(CloseMessage),
109}
110
111impl Message {
112    /// Create a corresponding WS frame.
113    fn into_frame(self) -> Frame {
114        match self {
115            Self::Text(text) => Frame::new(Frame::OPCODE_TEXT, text.into(), true),
116            Self::Data(data) => Frame::new(Frame::OPCODE_BINARY, data, true),
117            Self::Ping(data) => Frame::new(Frame::OPCODE_PING, data, true),
118            Self::Pong(data) => Frame::new(Frame::OPCODE_PONG, data, true),
119            Self::Close(close) => close.into_frame(),
120        }
121    }
122}
123
124impl From<CloseMessage> for Message {
125    #[inline]
126    fn from(close: CloseMessage) -> Self {
127        Self::Close(close)
128    }
129}
130
/// WS close message.
///
/// It carries a numeric status code and a (possibly empty) text message.
#[derive(Clone, Debug)]
pub struct CloseMessage {
    /// Close status code.
    status: u16,
    /// Text message; borrowed for static strings, owned otherwise.
    message: Cow<'static, str>,
}
137
138impl CloseMessage {
139    pub const STATUS_OK: u16 = 1000;
140    pub const STATUS_GOING_AWAY: u16 = 1001;
141    pub const STATUS_PROTOCOL_ERROR: u16 = 1002;
142    pub const STATUS_UNEXPECTED_DATA: u16 = 1003;
143    pub const STATUS_INVALID_DATA: u16 = 1007;
144    pub const STATUS_TOO_BIG: u16 = 1009;
145
146    /// Create a new close message with a given status code and a given text
147    /// message.
148    pub fn new<T>(status: u16, msg: T) -> Self
149    where
150        T: ToString,
151    {
152        Self {
153            status,
154            message: Cow::Owned(msg.to_string()),
155        }
156    }
157
158    /// Create a new close message with a given status code and a given text
159    /// message.
160    #[inline]
161    pub const fn new_static(status: u16, msg: &'static str) -> Self {
162        Self {
163            status,
164            message: Cow::Borrowed(msg),
165        }
166    }
167
168    /// Get the status code.
169    #[inline]
170    pub fn status(&self) -> u16 {
171        self.status
172    }
173
174    /// Get the close message.
175    #[inline]
176    pub fn message(&self) -> &str {
177        &self.message
178    }
179
180    /// Get the corresponding WS frame.
181    fn into_frame(self) -> Frame {
182        let mut data = BytesMut::with_capacity(self.message.len() + 2);
183
184        data.put_u16(self.status);
185        data.extend_from_slice(self.message.as_bytes());
186
187        Frame::new(Frame::OPCODE_CLOSE, data.freeze(), true)
188    }
189}
190
191/// WebSocket.
192pub struct WebSocket {
193    inner: Option<FrameSocket>,
194    current_msg_type: Option<u8>,
195    current_msg_data: Vec<u8>,
196    input_buffer_capacity: usize,
197    closed: bool,
198}
199
200impl WebSocket {
201    /// Create a new WS client.
202    #[cfg(feature = "client")]
203    #[cfg_attr(docsrs, doc(cfg(feature = "client")))]
204    #[inline]
205    pub fn client() -> ClientHandshakeBuilder {
206        ClientHandshake::builder()
207    }
208
209    /// Create a new WS server.
210    #[cfg(feature = "server")]
211    #[cfg_attr(docsrs, doc(cfg(feature = "server")))]
212    #[inline]
213    pub fn server(request: IncomingRequest) -> Result<ServerHandshake, Error> {
214        ServerHandshake::new(request)
215    }
216
217    /// Create a new WS from a given connection.
218    #[inline]
219    pub fn new(upgraded: Upgraded, agent_role: AgentRole, input_buffer_capacity: usize) -> Self {
220        let inner = FrameSocket::new(upgraded, agent_role, input_buffer_capacity);
221
222        Self {
223            inner: Some(inner),
224            current_msg_type: None,
225            current_msg_data: Vec::new(),
226            input_buffer_capacity,
227            closed: false,
228        }
229    }
230
231    /// Process a given WS frame.
232    fn process_frame(&mut self, frame: Frame) -> Result<Option<Message>, InternalError> {
233        let opcode = frame.opcode();
234        let fin = frame.fin();
235        let data = frame.into_payload();
236
237        match opcode {
238            Frame::OPCODE_CONTINUATION => self.process_continuation_frame(&data, fin),
239            Frame::OPCODE_BINARY => self.process_binary_frame(data, fin),
240            Frame::OPCODE_TEXT => self.process_text_frame(data, fin),
241            Frame::OPCODE_PING => self.process_ping_frame(data, fin),
242            Frame::OPCODE_PONG => self.process_pong_frame(data, fin),
243            Frame::OPCODE_CLOSE => self.process_close_frame(data, fin),
244            _ => Err(InternalError::ProtocolError),
245        }
246    }
247
248    /// Process a given WS frame.
249    fn process_continuation_frame(
250        &mut self,
251        data: &[u8],
252        fin: bool,
253    ) -> Result<Option<Message>, InternalError> {
254        let msg_type = self.current_msg_type.ok_or(InternalError::ProtocolError)?;
255
256        if (self.current_msg_data.len() + data.len()) > self.input_buffer_capacity {
257            return Err(InternalError::MessageSizeExceeded);
258        }
259
260        self.current_msg_data.extend(data);
261
262        if !fin {
263            return Ok(None);
264        }
265
266        self.current_msg_type = None;
267
268        let data = Bytes::from(std::mem::take(&mut self.current_msg_data));
269
270        match msg_type {
271            Frame::OPCODE_BINARY => self.process_binary_frame(data, true),
272            Frame::OPCODE_TEXT => self.process_text_frame(data, true),
273            _ => unreachable!(),
274        }
275    }
276
277    /// Process a given WS frame.
278    fn process_binary_frame(
279        &mut self,
280        data: Bytes,
281        fin: bool,
282    ) -> Result<Option<Message>, InternalError> {
283        if self.current_msg_type.is_some() {
284            return Err(InternalError::ProtocolError);
285        }
286
287        if fin {
288            Ok(Some(Message::Data(data)))
289        } else {
290            self.current_msg_type = Some(Frame::OPCODE_BINARY);
291            self.current_msg_data = data.to_vec();
292
293            Ok(None)
294        }
295    }
296
297    /// Process a given WS frame.
298    fn process_text_frame(
299        &mut self,
300        data: Bytes,
301        fin: bool,
302    ) -> Result<Option<Message>, InternalError> {
303        if self.current_msg_type.is_some() {
304            return Err(InternalError::ProtocolError);
305        }
306
307        if fin {
308            let text = std::str::from_utf8(&data)
309                .map_err(|_| InternalError::InvalidString)?
310                .to_string();
311
312            Ok(Some(Message::Text(text)))
313        } else {
314            self.current_msg_type = Some(Frame::OPCODE_TEXT);
315            self.current_msg_data = data.to_vec();
316
317            Ok(None)
318        }
319    }
320
321    /// Process a given WS frame.
322    fn process_ping_frame(
323        &mut self,
324        data: Bytes,
325        fin: bool,
326    ) -> Result<Option<Message>, InternalError> {
327        if !fin {
328            return Err(InternalError::ProtocolError);
329        }
330
331        Ok(Some(Message::Ping(data)))
332    }
333
334    /// Process a given WS frame.
335    fn process_pong_frame(
336        &mut self,
337        data: Bytes,
338        fin: bool,
339    ) -> Result<Option<Message>, InternalError> {
340        if !fin {
341            return Err(InternalError::ProtocolError);
342        }
343
344        Ok(Some(Message::Pong(data)))
345    }
346
347    /// Process a given WS frame.
348    fn process_close_frame(
349        &mut self,
350        mut data: Bytes,
351        fin: bool,
352    ) -> Result<Option<Message>, InternalError> {
353        if !fin {
354            return Err(InternalError::ProtocolError);
355        }
356
357        let status = if data.len() < 2 {
358            // drop any remaining content
359            data.clear();
360
361            1005
362        } else {
363            data.get_u16()
364        };
365
366        let msg = std::str::from_utf8(&data)
367            .map_err(|_| InternalError::InvalidString)?
368            .to_string();
369
370        let msg = CloseMessage::new(status, msg);
371
372        self.closed = true;
373
374        Ok(Some(msg.into()))
375    }
376
377    /// Poll the next WS message.
378    fn poll_next_inner(
379        &mut self,
380        cx: &mut Context<'_>,
381    ) -> Poll<Option<Result<Message, InternalError>>> {
382        loop {
383            if self.closed {
384                return Poll::Ready(None);
385            } else if let Some(inner) = self.inner.as_mut() {
386                if let Poll::Ready(ready) = inner.poll_next_unpin(cx) {
387                    if let Some(frame) = ready.transpose()? {
388                        if let Some(msg) = self.process_frame(frame)? {
389                            return Poll::Ready(Some(Ok(msg)));
390                        }
391                    } else {
392                        return Poll::Ready(None);
393                    }
394                } else {
395                    return Poll::Pending;
396                }
397            } else {
398                return Poll::Ready(None);
399            }
400        }
401    }
402}
403
404impl Stream for WebSocket {
405    type Item = io::Result<Message>;
406
407    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
408        match ready!(self.poll_next_inner(cx)) {
409            Some(Ok(msg)) => Poll::Ready(Some(Ok(msg))),
410            Some(Err(err)) => {
411                if let Some(msg) = err.to_close_message() {
412                    if let Some(mut inner) = self.inner.take() {
413                        tokio::spawn(async move {
414                            let _ = inner.send(msg.into_frame()).await;
415                        });
416                    }
417                }
418
419                let err = match err {
420                    InternalError::UnexpectedEof => io::Error::from(io::ErrorKind::UnexpectedEof),
421                    InternalError::IO(err) => err,
422                    _ => io::Error::from(io::ErrorKind::InvalidData),
423                };
424
425                Poll::Ready(Some(Err(err)))
426            }
427            None => Poll::Ready(None),
428        }
429    }
430}
431
432impl Sink<Message> for WebSocket {
433    type Error = io::Error;
434
435    #[inline]
436    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
437        self.inner
438            .as_mut()
439            .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?
440            .poll_ready_unpin(cx)
441    }
442
443    fn start_send(mut self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> {
444        // making borrow checker happy
445        let this = &mut *self;
446
447        if this.closed {
448            return Err(io::Error::from(io::ErrorKind::BrokenPipe));
449        }
450
451        let inner = this
452            .inner
453            .as_mut()
454            .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?;
455
456        let frame = msg.into_frame();
457
458        this.closed |= frame.opcode() == Frame::OPCODE_CLOSE;
459
460        inner.start_send_unpin(frame)
461    }
462
463    #[inline]
464    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
465        self.inner
466            .as_mut()
467            .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?
468            .poll_flush_unpin(cx)
469    }
470
471    #[inline]
472    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
473        self.inner
474            .as_mut()
475            .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?
476            .poll_close_unpin(cx)
477    }
478}
479
480/// WS frame socket.
481struct FrameSocket {
482    upgraded: Upgraded,
483    agent_role: AgentRole,
484    input_buffer: BytesMut,
485    output_buffer: BytesMut,
486    input_buffer_capacity: usize,
487    sent: usize,
488}
489
490impl FrameSocket {
491    /// Create a new frame socket from a given connection.
492    #[inline]
493    fn new(upgraded: Upgraded, agent_role: AgentRole, input_buffer_capacity: usize) -> Self {
494        Self {
495            upgraded,
496            agent_role,
497            input_buffer: BytesMut::new(),
498            output_buffer: BytesMut::new(),
499            input_buffer_capacity,
500            sent: 0,
501        }
502    }
503}
504
505impl Stream for FrameSocket {
506    type Item = Result<Frame, InternalError>;
507
508    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
509        let mut buffer: [MaybeUninit<u8>; 8192] = unsafe { MaybeUninit::uninit().assume_init() };
510
511        // making borrow checker happy
512        let this = &mut *self;
513
514        loop {
515            if let Some(frame) = Frame::decode(&mut this.input_buffer, this.agent_role)? {
516                return Poll::Ready(Some(Ok(frame)));
517            } else if this.input_buffer.len() >= this.input_buffer_capacity {
518                return Poll::Ready(Some(Err(InternalError::MessageSizeExceeded)));
519            }
520
521            let available = this.input_buffer_capacity - this.input_buffer.len();
522            let read = available.min(buffer.len());
523
524            let mut buffer = ReadBuf::uninit(&mut buffer[..read]);
525
526            let pinned = Pin::new(&mut this.upgraded);
527
528            ready!(pinned.poll_read(cx, &mut buffer))?;
529
530            let filled = buffer.filled();
531
532            if !filled.is_empty() {
533                this.input_buffer.extend_from_slice(filled);
534            } else if this.input_buffer.is_empty() {
535                return Poll::Ready(None);
536            } else {
537                return Poll::Ready(Some(Err(InternalError::UnexpectedEof)));
538            }
539        }
540    }
541}
542
543impl Sink<Frame> for FrameSocket {
544    type Error = io::Error;
545
546    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
547        // making borrow checker happy
548        let this = &mut *self;
549
550        while this.sent < this.output_buffer.len() {
551            let pinned = Pin::new(&mut this.upgraded);
552
553            let len = ready!(pinned.poll_write(cx, &this.output_buffer[this.sent..]))?;
554
555            this.sent += len;
556        }
557
558        this.output_buffer.clear();
559        this.sent = 0;
560
561        Poll::Ready(Ok(()))
562    }
563
564    fn start_send(mut self: Pin<&mut Self>, frame: Frame) -> Result<(), Self::Error> {
565        // making borrow checker happy
566        let this = &mut *self;
567
568        frame.encode(&mut this.output_buffer, this.agent_role);
569
570        Ok(())
571    }
572
573    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
574        ready!(self.poll_ready_unpin(cx))?;
575
576        let pinned = Pin::new(&mut self.upgraded);
577
578        pinned.poll_flush(cx)
579    }
580
581    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
582        ready!(self.poll_ready_unpin(cx))?;
583
584        let pinned = Pin::new(&mut self.upgraded);
585
586        pinned.poll_shutdown(cx)
587    }
588}