rusty_socket/
frame.rs

1//! WebSocket frame implementation.
2
3use smol::io::{AsyncRead, AsyncReadExt};
4use std::convert::TryFrom;
5
6use crate::{Error, Result};
7
/// The opcode of a WebSocket frame (the low four bits of the first header byte).
///
/// Each variant's discriminant is its on-the-wire value, so `opcode as u8`
/// yields exactly the bits to place in the frame header. Without explicit
/// discriminants the variants would be numbered 0..=5, which is wrong for the
/// control opcodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum OpCode {
    /// Indicates a continuation frame (`0x0`).
    Continuation = 0x0,
    /// Indicates a text frame (`0x1`).
    Text = 0x1,
    /// Indicates a binary frame (`0x2`).
    Binary = 0x2,
    /// Indicates a close frame (`0x8`).
    Close = 0x8,
    /// Indicates a ping frame (`0x9`).
    Ping = 0x9,
    /// Indicates a pong frame (`0xA`).
    Pong = 0xA,
}
24
25impl TryFrom<u8> for OpCode {
26    type Error = Error;
27
28    /// Tries to convert a u8 value into an OpCode.
29    ///
30    /// # Errors
31    ///
32    /// Returns an error if the value is not a valid OpCode.
33    fn try_from(value: u8) -> Result<Self> {
34        match value {
35            0 => Ok(OpCode::Continuation),
36            1 => Ok(OpCode::Text),
37            2 => Ok(OpCode::Binary),
38            8 => Ok(OpCode::Close),
39            9 => Ok(OpCode::Ping),
40            10 => Ok(OpCode::Pong),
41            _ => Err(Error::Protocol(format!("Invalid OpCode: {}", value))),
42        }
43    }
44}
45
46/// Represents a WebSocket frame.
47#[derive(Debug)]
48pub struct Frame {
49    /// Indicates if this is the final fragment in a message.
50    pub fin: bool,
51    /// First reserved bit.
52    pub rsv1: bool,
53    /// Second reserved bit.
54    pub rsv2: bool,
55    /// Third reserved bit.
56    pub rsv3: bool,
57    /// The opcode for this frame.
58    pub opcode: OpCode,
59    /// The masking key, if any.
60    pub mask: Option<[u8; 4]>,
61    /// The payload data.
62    pub payload: Vec<u8>,
63}
64
65impl Frame {
66    /// Creates a new Frame with the given opcode and payload.
67    pub fn new(opcode: OpCode, payload: Vec<u8>) -> Self {
68        Frame {
69            fin: true,
70            rsv1: false,
71            rsv2: false,
72            rsv3: false,
73            opcode,
74            mask: None,
75            payload,
76        }
77    }
78
79    /// Creates a close frame with an optional status code.
80    pub fn close(status_code: Option<u16>) -> Self {
81        let payload = status_code.map(|code| code.to_be_bytes().to_vec()).unwrap_or_default();
82        Frame::new(OpCode::Close, payload)
83    }
84
85    /// Checks if this frame is a close frame.
86    pub fn is_close(&self) -> bool {
87        self.opcode == OpCode::Close
88    }
89
90    /// Checks if this frame is masked.
91    pub fn is_masked(&self) -> bool {
92        self.mask.is_some()
93    }
94
95    /// Reads a frame from the given AsyncRead stream.
96    ///
97    /// # Errors
98    ///
99    /// Returns an error if reading from the stream fails or if the frame is invalid.
100    pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Self> {
101        let mut buf = [0u8; 2];
102        reader.read_exact(&mut buf).await?;
103        let first_byte = buf[0];
104        let second_byte = buf[1];
105
106        let fin = (first_byte & 0x80) != 0;
107        let rsv1 = (first_byte & 0x40) != 0;
108        let rsv2 = (first_byte & 0x20) != 0;
109        let rsv3 = (first_byte & 0x10) != 0;
110        let opcode = OpCode::try_from(first_byte & 0x0F)?;
111
112        let masked = (second_byte & 0x80) != 0;
113        let mut payload_len = (second_byte & 0x7F) as u64;
114
115        if payload_len == 126 {
116            let mut buf = [0u8; 2];
117            reader.read_exact(&mut buf).await?;
118            payload_len = u16::from_be_bytes(buf) as u64;
119        } else if payload_len == 127 {
120            let mut buf = [0u8; 8];
121            reader.read_exact(&mut buf).await?;
122            payload_len = u64::from_be_bytes(buf);
123        }
124
125        let mask = if masked {
126            let mut mask_bytes = [0u8; 4];
127            reader.read_exact(&mut mask_bytes).await?;
128            Some(mask_bytes)
129        } else {
130            None
131        };
132
133        let mut payload = vec![0u8; payload_len as usize];
134        reader.read_exact(&mut payload).await?;
135
136        if let Some(mask) = mask {
137            for (i, byte) in payload.iter_mut().enumerate() {
138                *byte ^= mask[i % 4];
139            }
140        }
141
142        Ok(Frame {
143            fin,
144            rsv1,
145            rsv2,
146            rsv3,
147            opcode,
148            mask,
149            payload,
150        })
151    }
152
153    /// Converts the frame to a byte vector.
154    pub fn to_bytes(&self) -> Vec<u8> {
155        let mut bytes = Vec::new();
156
157        let mut first_byte = 0;
158        if self.fin {
159            first_byte |= 0x80;
160        }
161        if self.rsv1 {
162            first_byte |= 0x40;
163        }
164        if self.rsv2 {
165            first_byte |= 0x20;
166        }
167        if self.rsv3 {
168            first_byte |= 0x10;
169        }
170        first_byte |= self.opcode as u8;
171        bytes.push(first_byte);
172
173        let mut second_byte = 0;
174        if self.mask.is_some() {
175            second_byte |= 0x80;
176        }
177
178        let payload_len = self.payload.len();
179        if payload_len < 126 {
180            second_byte |= payload_len as u8;
181            bytes.push(second_byte);
182        } else if payload_len < 65536 {
183            second_byte |= 126;
184            bytes.push(second_byte);
185            bytes.extend_from_slice(&(payload_len as u16).to_be_bytes());
186        } else {
187            second_byte |= 127;
188            bytes.push(second_byte);
189            bytes.extend_from_slice(&(payload_len as u64).to_be_bytes());
190        }
191
192        if let Some(mask) = self.mask {
193            bytes.extend_from_slice(&mask);
194        }
195
196        bytes.extend_from_slice(&self.payload);
197        bytes
198    }
199}