websocket_sans_io/
frame_decoding.rs

1use crate::{PayloadLength, Opcode, FrameInfo, masking};
2
3use nonmax::NonMaxU8;
4
/// Error type produced by the frame decoder.
///
/// When the `large_frames` crate feature is on (it is by default), any bytes can be decoded,
/// so no error is possible and this is just an alias for [`core::convert::Infallible`].
#[cfg(feature = "large_frames")]
pub type FrameDecoderError = core::convert::Infallible;

/// Error type produced by the frame decoder.
///
/// When the `large_frames` crate feature is off (like now), WebSocket frame headers denoting
/// large frames produce this error.
#[cfg(not(feature = "large_frames"))]
#[allow(missing_docs)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum FrameDecoderError {
    ExceededFrameSize,
}
17
/// Tiny fixed-capacity byte accumulator used to collect header fields that may arrive
/// split across several input chunks.
#[derive(Clone, Copy, Debug)]
struct SmallBufWithLen<const C: usize> {
    /// Number of leading bytes of `data` that have been filled so far.
    len: u8,
    /// Backing storage; only `data[..len]` holds collected bytes.
    data: [u8; C],
}

impl<const C: usize> SmallBufWithLen<C> {
    /// Creates an empty buffer with zeroed storage.
    const fn new() -> SmallBufWithLen<C> {
        SmallBufWithLen {
            data: [0u8; C],
            len: 0,
        }
    }

    /// Tells whether all `C` bytes have been collected.
    fn is_full(&self) -> bool {
        C == usize::from(self.len)
    }

    /// Moves as many bytes as still fit from the front of `data` into this buffer and
    /// returns the part of `data` that was not taken.
    fn slurp<'a, 'c>(&'c mut self, data: &'a mut [u8]) -> &'a mut [u8] {
        let filled = usize::from(self.len);
        // Take whatever is smaller: the free room in the buffer or the supplied input.
        let take = data.len().min(C - filled);
        let (head, tail) = data.split_at_mut(take);
        self.data[filled..filled + take].copy_from_slice(head);
        self.len += take as u8;
        tail
    }
}
43
44/// Represents what data is expected to come next
45#[derive(Clone, Copy, Debug)]
46enum FrameDecodingState {
47    HeaderBeginning(SmallBufWithLen<2>),
48    PayloadLength16(SmallBufWithLen<2>),
49    #[cfg(feature="large_frames")]
50    PayloadLength64(SmallBufWithLen<8>),
51    MaskingKey(SmallBufWithLen<4>),
52    PayloadData {
53        phase: Option<NonMaxU8>,
54        remaining: PayloadLength,
55    },
56}
57
58impl Default for FrameDecodingState {
59    fn default() -> Self {
60        FrameDecodingState::HeaderBeginning(SmallBufWithLen::new())
61    }
62}
63
64/// A low-level WebSocket frames decoder.
65/// 
66/// It is a push parser: you can add offer it bytes that come from a socket and it emites events.
67/// 
68/// You typically need two loops to process incoming data: outer loop reads chunks of data
69/// from sockets, inner loop supplies this chunk to the decoder instance until no more events get emitted.
70/// 
71/// Example usage:
72/// 
73/// ```
74#[doc=include_str!("../examples/decode_frame.rs")]
75/// ```
76/// 
77/// Any sequence of bytes result in a some (sensial or not) [`WebsocketFrameEvent`]
78/// sequence (exception: when `large_frames` crate feature is disabled).
79/// 
80/// You may want to validate it (e.g. using [`FrameInfo::is_reasonable`] method) before using.
81#[derive(Clone, Copy, Debug, Default)]
82pub struct WebsocketFrameDecoder {
83    state: FrameDecodingState,
84    mask: [u8; 4],
85    basic_header: [u8; 2],
86    payload_length: PayloadLength,
87    original_opcode: Opcode,
88}
89
90/// Return value of [`WebsocketFrameDecoder::add_data`] call.
91#[derive(Debug,Clone)]
92pub struct WebsocketFrameDecoderAddDataResult {
93    /// Indicates how many bytes were consumed and should not be supplied again to
94    /// the subsequent invocation of [`WebsocketFrameDecoder::add_data`].
95    /// 
96    /// When `add_data` procudes [`WebsocketFrameEvent::PayloadChunk`], it also indicated how many
97    /// of the bytes in the buffer (starting from 0) should be used as a part of payload.
98    pub consumed_bytes: usize,
99    /// Emitted event, if any.
100    pub event: Option<WebsocketFrameEvent>,
101}
102
103#[allow(missing_docs)]
104/// Information that [`WebsocketFrameDecoder`] gives in return to bytes being fed to it.
105#[derive(Debug, PartialEq, Eq, Clone)]
106pub enum WebsocketFrameEvent {
107    /// Indicates a frame is started.
108    /// 
109    /// `original_opcode` is the same as `frame_info.opcode`, except for
110    /// [`Opcode::Continuation`] frames, for which it should refer to
111    /// initial frame in sequence (i.e. [`Opcode::Text`] or [`Opcode::Binary`])
112    Start{frame_info: FrameInfo, original_opcode: Opcode},
113
114    /// Bytes which were supplied to [`WebsocketFrameDecoder::add_data`] are payload bytes,
115    /// transformed for usage as a part of payload.
116    /// 
117    /// You should use [`WebsocketFrameDecoderAddDataResult::consumed_bytes`] to get actual
118    /// buffer to be handled as content coming from the WebSocket.
119    /// 
120    /// Mind the `original_opcode` to avoid mixing content of control frames and data frames.
121    PayloadChunk{ original_opcode: Opcode},
122
123    /// Indicates that all `PayloadChunk`s for the given frame are delivered and the frame
124    /// is ended.
125    /// 
126    /// You can watch for `frame_info.fin` together with checking `original_opcode` to know
127    /// wnen WebSocket **message** (not just a frame) ends.
128    /// 
129    /// `frame_info` is the same as in [`WebsocketFrameEvent::Start`]'s `frame_info`.
130    End{frame_info: FrameInfo, original_opcode: Opcode},
131}
132
133impl WebsocketFrameDecoder {
134    fn get_opcode(&self) -> Opcode {
135        use Opcode::*;
136        match self.basic_header[0] & 0xF {
137            0 => Continuation,
138            1 => Text,
139            2 => Binary,
140            3 => ReservedData3,
141            4 => ReservedData4,
142            5 => ReservedData5,
143            6 => ReservedData6,
144            7 => ReservedData7,
145            8 => ConnectionClose,
146            9 => Ping,
147            0xA => Pong,
148            0xB => ReservedControlB,
149            0xC => ReservedControlC,
150            0xD => ReservedControlD,
151            0xE => ReservedControlE,
152            0xF => ReservedControlF,
153            _ => unreachable!(),
154        }
155    }
156
157    /// Get frame info and original opcode
158    fn get_frame_info(&self, masked: bool) -> (FrameInfo, Opcode) {
159        let fi = FrameInfo {
160            opcode: self.get_opcode(),
161            payload_length: self.payload_length,
162            mask: if masked { Some(self.mask) } else { None },
163            fin: self.basic_header[0] & 0x80 == 0x80,
164            reserved: (self.basic_header[0] & 0x70) >> 4,
165        };
166        let mut original_opcode = fi.opcode;
167        if original_opcode==Opcode::Continuation {
168            original_opcode = self.original_opcode;
169        }
170        (fi, original_opcode)
171    }
172
173    /// Add some bytes to the decoder and return events, if any.
174    /// 
175    /// Call this function again if any of the following conditions are met:
176    ///
177    /// * When new incoming data is available on the socket
178    /// * When previous invocation of `add_data` returned nonzero [`WebsocketFrameDecoderAddDataResult::consumed_bytes`].
179    /// * When previous invocation of `add_data` returned non-`None` [`WebsocketFrameDecoderAddDataResult::event`].
180    /// 
181    /// You may need call it with empty `data` buffer to get some final [`WebsocketFrameEvent::End`].
182    /// 
183    /// Input buffer needs to be mutable because it is also used to transform (unmask)
184    /// payload content chunks in-place.
185    pub fn add_data<'a, 'b>(
186        &'a mut self,
187        mut data: &'b mut [u8],
188    ) -> Result<WebsocketFrameDecoderAddDataResult, FrameDecoderError> {
189        let original_data_len = data.len();
190        loop {
191            macro_rules! return_dummy {
192                () => {
193                    return Ok(WebsocketFrameDecoderAddDataResult {
194                        consumed_bytes: original_data_len - data.len(),
195                        event: None,
196                    });
197                };
198            }
199            if data.len() == 0 && ! matches!(self.state, FrameDecodingState::PayloadData{remaining: 0, ..}) {
200                return_dummy!();
201            }
202            macro_rules! try_to_fill_buffer_or_return {
203                ($v:ident) => {
204                    data = $v.slurp(data);
205                    if !$v.is_full() {
206                        assert!(data.is_empty());
207                        return_dummy!();
208                    }
209                    let $v = $v.data;
210                };
211            }
212            let mut length_is_ready = false;
213            match self.state {
214                FrameDecodingState::HeaderBeginning(ref mut v) => {
215                    try_to_fill_buffer_or_return!(v);
216                    self.basic_header = v;
217                    let opcode = self.get_opcode();
218                    if opcode.is_data() && opcode != Opcode::Continuation {
219                        self.original_opcode = opcode;
220                    }
221                    match self.basic_header[1] & 0x7F {
222                        0x7E => {
223                            self.state = FrameDecodingState::PayloadLength16(SmallBufWithLen::new())
224                        }
225                        #[cfg(feature="large_frames")]
226                        0x7F => {
227                            self.state = FrameDecodingState::PayloadLength64(SmallBufWithLen::new())
228                        }
229                        #[cfg(not(feature="large_frames"))] 0x7F => {
230                            return Err(FrameDecoderError::ExceededFrameSize);
231                        }
232                        x => {
233                            self.payload_length = x.into();
234                            length_is_ready = true;
235                        }
236                    };
237                }
238                FrameDecodingState::PayloadLength16(ref mut v) => {
239                    try_to_fill_buffer_or_return!(v);
240                    self.payload_length = u16::from_be_bytes(v).into();
241                    length_is_ready = true;
242                }
243                #[cfg(feature="large_frames")]
244                FrameDecodingState::PayloadLength64(ref mut v) => {
245                    try_to_fill_buffer_or_return!(v);
246                    self.payload_length = u64::from_be_bytes(v);
247                    length_is_ready = true;
248                }
249                FrameDecodingState::MaskingKey(ref mut v) => {
250                    try_to_fill_buffer_or_return!(v);
251                    self.mask = v;
252                    self.state = FrameDecodingState::PayloadData {
253                        phase: Some(NonMaxU8::default()),
254                        remaining: self.payload_length,
255                    };
256                    let (frame_info, original_opcode) = self.get_frame_info(true);
257                    return Ok(WebsocketFrameDecoderAddDataResult {
258                        consumed_bytes: original_data_len - data.len(),
259                        event: Some(WebsocketFrameEvent::Start{frame_info, original_opcode}),
260                    });
261                }
262                FrameDecodingState::PayloadData {
263                    phase,
264                    remaining: 0,
265                } => {
266                    self.state = FrameDecodingState::HeaderBeginning(SmallBufWithLen::new());
267                    let (fi, original_opcode) = self.get_frame_info(phase.is_some());
268                    if fi.opcode.is_data() && fi.fin {
269                        self.original_opcode = Opcode::Continuation;
270                    }
271                    return Ok(WebsocketFrameDecoderAddDataResult {
272                        consumed_bytes: original_data_len - data.len(),
273                        event: Some(WebsocketFrameEvent::End{frame_info: fi, original_opcode}
274                            ),
275                    });
276                }
277                FrameDecodingState::PayloadData {
278                    ref mut phase,
279                    ref mut remaining,
280                } => {
281                    let start_offset = original_data_len - data.len();
282                    let mut max_len = data.len();
283                    if let Ok(remaining_usize) = usize::try_from(*remaining) {
284                        max_len = max_len.min(remaining_usize);
285                    }
286                    let (payload_chunk, _rest) = data.split_at_mut(max_len);
287
288                    if let Some(phase) = phase {
289                        let mut ph = phase.get();
290                        masking::apply_mask(self.mask, payload_chunk, ph);
291                        ph += payload_chunk.len() as u8;
292                        *phase = NonMaxU8::new(ph & 0x03).unwrap();
293                    }
294
295                    *remaining -= max_len as PayloadLength;
296                    let mut original_opcode = self.get_opcode();
297                    if original_opcode == Opcode::Continuation {
298                        original_opcode = self.original_opcode;
299                    }
300                    assert_eq!(start_offset, 0);
301                    return Ok(WebsocketFrameDecoderAddDataResult {
302                        consumed_bytes: max_len,
303                        event: Some(WebsocketFrameEvent::PayloadChunk{original_opcode}),
304                    });
305                }
306            }
307            if length_is_ready {
308                if self.basic_header[1] & 0x80 == 0x80 {
309                    self.state = FrameDecodingState::MaskingKey(SmallBufWithLen::new());
310                } else {
311                    self.state = FrameDecodingState::PayloadData {
312                        phase: None,
313                        remaining: self.payload_length,
314                    };
315                    let (frame_info, original_opcode) = self.get_frame_info(false);
316                    return Ok(WebsocketFrameDecoderAddDataResult {
317                        consumed_bytes: original_data_len - data.len(),
318                        event: Some(WebsocketFrameEvent::Start{frame_info, original_opcode}),
319                    });
320                }
321            }
322        }
323    }
324
325    /// There is no incomplete WebSocket frame at this moment and EOF is valid here.
326    ///
327    /// This method is not related to [`Opcode::ConnectionClose`] in any way.
328    #[inline]
329    pub fn eof_valid(&self) -> bool {
330        matches!(self.state, FrameDecodingState::HeaderBeginning(..))
331    }
332
333    /// Create new instance.
334    #[inline]
335    pub const fn new() -> Self {
336        WebsocketFrameDecoder {
337            state: FrameDecodingState::HeaderBeginning(SmallBufWithLen::new()),
338            mask: [0; 4],
339            basic_header: [0; 2],
340            payload_length: 0,
341            original_opcode: Opcode::Continuation,
342        }
343    }
344}