ws_tool/
frame.rs

1use crate::codec::apply_mask;
2use bytes::{BufMut, BytesMut};
3use std::fmt::Debug;
4
/// Defines the interpretation of the "Payload data".  If an unknown
/// opcode is received, the receiving endpoint MUST _Fail the
/// WebSocket Connection_.  The following values are defined.
/// - x0 denotes a continuation frame
/// - x1 denotes a text frame
/// - x2 denotes a binary frame
/// - x3-7 are reserved for further non-control frames
/// - x8 denotes a connection close
/// - x9 denotes a ping
/// - xA denotes a pong
/// - xB-F are reserved for further control frames
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
#[repr(u8)]
pub enum OpCode {
    /// - x0 denotes a continuation frame
    Continue = 0x0,
    /// - x1 denotes a text frame
    Text = 0x1,
    /// - x2 denotes a binary frame
    Binary = 0x2,
    /// - x3-7 are reserved for further non-control frames
    RNC3 = 0x3,
    /// - x3-7 are reserved for further non-control frames
    RNC4 = 0x4,
    /// - x3-7 are reserved for further non-control frames
    RNC5 = 0x5,
    /// - x3-7 are reserved for further non-control frames
    RNC6 = 0x6,
    /// - x3-7 are reserved for further non-control frames
    RNC7 = 0x7,
    /// - x8 denotes a connection close
    Close = 0x8,
    /// - x9 denotes a ping
    Ping = 0x9,
    /// - xA denotes a pong
    Pong = 0xA,
    /// - xB-F are reserved for further control frames
    RC11 = 0xB,
    /// - xB-F are reserved for further control frames
    RC12 = 0xC,
    /// - xB-F are reserved for further control frames
    RC13 = 0xD,
    /// - xB-F are reserved for further control frames
    RC14 = 0xE,
    /// - xB-F are reserved for further control frames
    RC15 = 0xF,
}

impl Default for OpCode {
    /// Defaults to [`OpCode::Text`].
    fn default() -> Self {
        Self::Text
    }
}

impl OpCode {
    /// Returns the 4-bit wire value of this opcode.
    pub fn as_u8(&self) -> u8 {
        *self as u8
    }

    /// Returns `true` for a connection close frame (`0x8`).
    pub fn is_close(&self) -> bool {
        matches!(self, Self::Close)
    }

    /// Returns `true` for data frames: text, binary and continuation.
    pub fn is_data(&self) -> bool {
        matches!(self, Self::Text | Self::Binary | Self::Continue)
    }

    /// Returns `true` for opcodes reserved by RFC 6455:
    /// `0x3`..=`0x7` (further non-control frames) and `0xB`..=`0xF`
    /// (further control frames).
    pub fn is_reserved(&self) -> bool {
        // The non-control range previously stopped at 0x5, wrongly
        // accepting 0x6 and 0x7 as non-reserved.
        matches!(self.as_u8(), 0x3..=0x7 | 0xB..=0xF)
    }
}
80
81#[inline]
82pub(crate) fn parse_opcode(val: u8) -> OpCode {
83    unsafe { std::mem::transmute(val & 0b00001111) }
84}
85
/// Reads one bit of `source[byte_idx]`.
///
/// `bit_idx` counts from the most significant bit: 0 selects `0x80`, 7 selects `0x01`.
///
/// # Panics
///
/// Panics if `bit_idx > 7` or if `byte_idx` is out of bounds. Bounds are checked
/// because callers can reach this with arbitrary buffers (e.g. headers built via
/// `Header::raw`), so an unchecked read would be undefined behavior from safe code.
#[inline]
pub(crate) fn get_bit(source: &[u8], byte_idx: usize, bit_idx: u8) -> bool {
    assert!(bit_idx < 8, "bit_idx must be in 0..8, got {}", bit_idx);
    let mask = 0x80u8 >> bit_idx;
    source[byte_idx] & mask == mask
}
101
/// Sets (`val == true`) or clears (`val == false`) one bit of `source[byte_idx]`.
///
/// `bit_idx` counts from the most significant bit: 0 is `0x80`, 7 is `0x01`.
///
/// # Panics
///
/// Panics if `bit_idx > 7` or if `byte_idx` is out of bounds.
#[inline]
pub(crate) fn set_bit(source: &mut [u8], byte_idx: usize, bit_idx: u8, val: bool) {
    let bit: u8 = match bit_idx {
        0..=7 => 0b1000_0000 >> bit_idx,
        _ => unreachable!(),
    };
    let byte = &mut source[byte_idx];
    if val {
        *byte |= bit;
    } else {
        *byte &= !bit;
    }
}
132
/// Generates the read-only accessors shared by the header types.
///
/// Expand this inside an inherent `impl` for a tuple struct whose field `.0`
/// dereferences to `[u8]` and holds an encoded frame header. The expansion uses
/// the free functions `get_bit` and `parse_opcode` and the `OpCode` type.
macro_rules! impl_get {
    () => {
        /// Reads one header bit; `bit_idx` 0 is the most significant bit.
        #[inline]
        fn get_bit(&self, byte_idx: usize, bit_idx: u8) -> bool {
            get_bit(&self.0, byte_idx, bit_idx)
        }

        /// Returns the FIN bit (most significant bit of byte 0).
        #[inline]
        pub fn fin(&self) -> bool {
            self.get_bit(0, 0)
        }

        /// Returns the RSV1 bit (second bit of byte 0).
        #[inline]
        pub fn rsv1(&self) -> bool {
            self.get_bit(0, 1)
        }

        /// Returns the RSV2 bit (third bit of byte 0).
        #[inline]
        pub fn rsv2(&self) -> bool {
            self.get_bit(0, 2)
        }

        /// Returns the RSV3 bit (fourth bit of byte 0).
        #[inline]
        pub fn rsv3(&self) -> bool {
            self.get_bit(0, 3)
        }

        /// Returns the frame opcode stored in the low four bits of byte 0.
        ///
        /// # Panics
        ///
        /// Panics if the header is empty.
        #[inline]
        pub fn opcode(&self) -> OpCode {
            // Checked indexing: `Header::raw` accepts arbitrary buffers, so an
            // unchecked read here could go out of bounds from safe code.
            parse_opcode(self.0[0])
        }

        /// Returns the MASK bit (most significant bit of byte 1).
        #[inline]
        pub fn masked(&self) -> bool {
            self.get_bit(1, 0)
        }

        /// Number of bytes used by byte 1 plus any extended payload length:
        /// 1 (7-bit length), 3 (16-bit length) or 9 (64-bit length).
        #[inline]
        fn len_bytes(&self) -> usize {
            // The low seven bits select the encoding; the MASK bit is irrelevant.
            match self.0[1] & 0x7F {
                126 => 3,
                127 => 9,
                _ => 1,
            }
        }

        /// Returns the decoded **payload** length.
        ///
        /// # Panics
        ///
        /// Panics if the header is shorter than its length encoding requires.
        #[inline]
        pub fn payload_len(&self) -> u64 {
            let header = &self.0;
            // Byte 1 (MASK bit + 7-bit length/marker) must exist; the old
            // check only required one byte before reading `header[1]`.
            assert!(header.len() >= 2);
            match header[1] & 0x7F {
                126 => {
                    assert!(header.len() >= 4);
                    u16::from_be_bytes([header[2], header[3]]) as u64
                }
                127 => {
                    assert!(header.len() >= 10);
                    u64::from_be_bytes((&header[2..10]).try_into().unwrap())
                }
                len => len as u64,
            }
        }

        /// Returns the 4-byte masking key when the MASK bit is set, else `None`.
        ///
        /// # Panics
        ///
        /// Panics if the MASK bit is set but the header is too short to hold the key.
        #[inline]
        pub fn masking_key(&self) -> Option<[u8; 4]> {
            if !self.masked() {
                return None;
            }
            // The key follows byte 0, byte 1 and any extended length bytes.
            let start = 1 + self.len_bytes();
            let mut key = [0u8; 4];
            key.copy_from_slice(&self.0[start..start + 4]);
            Some(key)
        }
    };
}
218
/// Returns the encoded header size in bytes for the given mask flag and payload length.
///
/// The size is one byte for FIN/RSV/opcode, plus 1, 3 or 9 bytes for the length field
/// (7-bit, 16-bit or 64-bit encoding), plus 4 bytes for the masking key when `mask` is set.
pub fn header_len(mask: bool, payload_len: u64) -> usize {
    let length_field = match payload_len {
        0..=125 => 1,
        126..=65535 => 3,
        _ => 9,
    };
    let key_bytes = if mask { 4 } else { 0 };
    1 + length_field + key_bytes
}
234
235#[inline]
236const fn first_byte(fin: bool, rsv1: bool, rsv2: bool, rsv3: bool, opcode: OpCode) -> u8 {
237    let leading = match (fin, rsv1, rsv2, rsv3) {
238        (true, true, true, true) => 0b1111_0000,
239        (true, true, true, false) => 0b1110_0000,
240        (true, true, false, true) => 0b1101_0000,
241        (true, true, false, false) => 0b1100_0000,
242        (true, false, true, true) => 0b1011_0000,
243        (true, false, true, false) => 0b1010_0000,
244        (true, false, false, true) => 0b1001_0000,
245        (true, false, false, false) => 0b1000_0000,
246        (false, true, true, true) => 0b0111_0000,
247        (false, true, true, false) => 0b0110_0000,
248        (false, true, false, true) => 0b0101_0000,
249        (false, true, false, false) => 0b0100_0000,
250        (false, false, true, true) => 0b0011_0000,
251        (false, false, true, false) => 0b0010_0000,
252        (false, false, false, true) => 0b0001_0000,
253        (false, false, false, false) => 0b0000_0000,
254    };
255    leading | opcode as u8
256}
257
258/// write header without allocation
259#[allow(clippy::too_many_arguments)]
260pub fn ctor_header<M: Into<Option<[u8; 4]>>>(
261    buf: &mut [u8],
262    fin: bool,
263    rsv1: bool,
264    rsv2: bool,
265    rsv3: bool,
266    mask_key: M,
267    opcode: OpCode,
268    payload_len: u64,
269) -> &[u8] {
270    let mask = mask_key.into();
271    let mut header_len = 1;
272    if mask.is_some() {
273        header_len += 4;
274    }
275    if payload_len <= 125 {
276        buf[1] = payload_len as u8;
277        header_len += 1;
278    } else if payload_len <= 65535 {
279        buf[1] = 126;
280        buf[2..4].copy_from_slice(&(payload_len as u16).to_be_bytes());
281        header_len += 3;
282    } else {
283        buf[1] = 127;
284        buf[2..10].copy_from_slice(&payload_len.to_be_bytes());
285        header_len += 9;
286    }
287    buf[0] = first_byte(fin, rsv1, rsv2, rsv3, opcode);
288    if let Some(key) = mask {
289        set_bit(buf, 1, 0, true);
290        buf[(header_len - 4)..header_len].copy_from_slice(&key);
291    } else {
292        set_bit(buf, 1, 0, false);
293    }
294    &buf[..header_len]
295}
296
#[test]
fn test_header() {
    /// Random masking key, present about half of the time.
    fn rand_mask() -> Option<[u8; 4]> {
        fastrand::bool().then(|| fastrand::u32(..).to_be_bytes())
    }

    /// Random opcode over all 16 values; `parse_opcode` avoids `unsafe`.
    fn rand_code() -> OpCode {
        parse_opcode(fastrand::u8(0..16))
    }

    /// Random payload length drawn evenly from the three length encodings.
    /// A uniform draw over `u64` would almost never hit the 7-bit or
    /// 16-bit encodings.
    fn rand_len() -> u64 {
        match fastrand::u8(0..3) {
            0 => fastrand::u64(0..=125),
            1 => fastrand::u64(126..=65535),
            _ => fastrand::u64(65536..=u64::MAX),
        }
    }

    // Always cover the encoding boundaries, then random samples.
    let boundaries = [0u64, 125, 126, 65535, 65536, u64::MAX];
    let lens = boundaries
        .iter()
        .copied()
        .chain((0..1000).map(|_| rand_len()));

    let mut buf = [0u8; 14];
    for payload_len in lens {
        let fin = fastrand::bool();
        let rsv1 = fastrand::bool();
        let rsv2 = fastrand::bool();
        let rsv3 = fastrand::bool();
        let mask_key = rand_mask();
        let opcode = rand_code();

        let slice = ctor_header(
            &mut buf,
            fin,
            rsv1,
            rsv2,
            rsv3,
            mask_key,
            opcode,
            payload_len,
        );
        let header = Header::new(fin, rsv1, rsv2, rsv3, mask_key, opcode, payload_len);
        assert_eq!(slice, header.as_bytes());
        assert_eq!(slice.len(), header_len(mask_key.is_some(), payload_len));

        // The encoded bytes must decode back to the same fields.
        let view = HeaderView(slice);
        assert_eq!(view.fin(), fin);
        assert_eq!(view.rsv1(), rsv1);
        assert_eq!(view.rsv2(), rsv2);
        assert_eq!(view.rsv3(), rsv3);
        assert_eq!(view.opcode(), opcode);
        assert_eq!(view.masked(), mask_key.is_some());
        assert_eq!(view.masking_key(), mask_key);
        assert_eq!(view.payload_len(), payload_len);
    }
}
331
332/// header with less info
333#[derive(Debug, Clone, Copy)]
334pub struct SimplifiedHeader {
335    /// fin
336    pub fin: bool,
337    /// compressed bit
338    pub rsv1: bool,
339    /// reserved
340    pub rsv2: bool,
341    /// reserved
342    pub rsv3: bool,
343    /// frame type
344    pub code: OpCode,
345}
346
347impl<'a> From<HeaderView<'a>> for SimplifiedHeader {
348    fn from(value: HeaderView<'a>) -> Self {
349        Self {
350            fin: value.fin(),
351            rsv1: value.rsv1(),
352            rsv2: value.rsv2(),
353            rsv3: value.rsv3(),
354            code: value.opcode(),
355        }
356    }
357}
358
359/// frame header
360#[derive(Debug, Clone, Copy)]
361pub struct HeaderView<'a>(pub(crate) &'a [u8]);
362
363impl<'a> HeaderView<'a> {
364    impl_get! {}
365}
366
367/// owned header buf
368#[derive(Debug, Clone)]
369pub struct Header(pub(crate) BytesMut);
370
371impl Header {
372    impl_get! {}
373    /// get header as bytes
374    pub fn as_bytes(&self) -> &[u8] {
375        &self.0
376    }
377
378    #[inline]
379    fn set_bit(&mut self, byte_idx: usize, bit_idx: u8, val: bool) {
380        set_bit(&mut self.0, byte_idx, bit_idx, val)
381    }
382
383    /// set fin bit
384    #[inline]
385    pub fn set_fin(&mut self, val: bool) {
386        self.set_bit(0, 0, val)
387    }
388
389    /// set rsv1 bit
390    #[inline]
391    pub fn set_rsv1(&mut self, val: bool) {
392        self.set_bit(0, 1, val)
393    }
394
395    /// set rsv2 bit
396    #[inline]
397    pub fn set_rsv2(&mut self, val: bool) {
398        self.set_bit(0, 2, val)
399    }
400
401    /// set rsv3 bit
402    #[inline]
403    pub fn set_rsv3(&mut self, val: bool) {
404        self.set_bit(0, 3, val)
405    }
406
407    /// set opcode
408    #[inline]
409    pub fn set_opcode(&mut self, code: OpCode) {
410        let header = &mut self.0;
411        let leading_bits = (header[0] >> 4) << 4;
412        header[0] = leading_bits | code.as_u8()
413    }
414
415    /// **NOTE** if change mask bit after setting payload
416    /// you need to set payload again to adjust data frame
417    #[inline]
418    pub fn set_mask(&mut self, mask: bool) {
419        self.set_bit(1, 0, mask);
420    }
421
422    /// set header payload lens
423    /// TODO do not overlay mask key
424    #[inline]
425    pub fn set_payload_len(&mut self, len: u64) {
426        let mask = self.masking_key();
427        let mask_len = mask.as_ref().map(|_| 4).unwrap_or_default();
428        let header = &mut self.0;
429        let mut leading_byte = header[1];
430        match len {
431            0..=125 => {
432                leading_byte &= 128;
433                header[1] = leading_byte | (len as u8);
434                let idx = 1 + 1;
435                header.resize(idx + mask_len, 0);
436                if let Some(mask) = mask {
437                    header[idx..].copy_from_slice(&mask);
438                }
439            }
440            126..=65535 => {
441                leading_byte &= 128;
442                header[1] = leading_byte | 126;
443                let len_arr = (len as u16).to_be_bytes();
444                let idx = 1 + 3;
445                header.resize(idx + mask_len, 0);
446                header[2] = len_arr[0];
447                header[3] = len_arr[1];
448                if let Some(mask) = mask {
449                    header[idx..].copy_from_slice(&mask);
450                }
451            }
452            _ => {
453                leading_byte &= 128;
454                header[1] = leading_byte | 127;
455                let len_arr = len.to_be_bytes();
456                let idx = 1 + 9;
457                header.resize(idx + mask_len, 0);
458                header[2..10].copy_from_slice(&len_arr[..8]);
459                if let Some(mask) = mask {
460                    header[idx..].copy_from_slice(&mask);
461                }
462            }
463        }
464    }
465
466    /// construct header without checking
467    pub fn raw(data: BytesMut) -> Self {
468        Self(data)
469    }
470
471    /// construct new header
472    pub fn new<M: Into<Option<[u8; 4]>>>(
473        fin: bool,
474        rsv1: bool,
475        rsv2: bool,
476        rsv3: bool,
477        mask_key: M,
478        opcode: OpCode,
479        payload_len: u64,
480    ) -> Self {
481        let mask = mask_key.into();
482        let len = header_len(mask.is_some(), payload_len);
483        assert!(len >= 2);
484        let mut buf = BytesMut::zeroed(len);
485        buf[0] = first_byte(fin, rsv1, rsv2, rsv3, opcode);
486        let mut header = Self(buf);
487        header.set_mask(mask.is_some());
488        header.set_payload_len(payload_len);
489        if let Some(mask) = mask {
490            header.0[(len - 4)..len].copy_from_slice(&mask);
491        }
492        header
493    }
494}
495
496/// owned frame
497#[derive(Debug, Clone)]
498pub struct OwnedFrame {
499    pub(crate) header: Header,
500    pub(crate) payload: BytesMut,
501}
502
503impl OwnedFrame {
504    /// construct new owned frame
505    #[inline]
506    pub fn new(code: OpCode, mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
507        let header = Header::new(true, false, false, false, mask, code, data.len() as u64);
508        let mut payload = BytesMut::with_capacity(data.len());
509        payload.extend_from_slice(data);
510        if let Some(mask) = header.masking_key() {
511            apply_mask(&mut payload, mask);
512        }
513        Self { header, payload }
514    }
515
516    /// use constructed header and payload
517    ///
518    /// **NOTE**: this will not check header and payload
519    #[inline]
520    pub fn with_raw(header: Header, payload: BytesMut) -> Self {
521        Self { header, payload }
522    }
523
524    /// helper function to construct a text frame
525    #[inline]
526    pub fn text_frame(mask: impl Into<Option<[u8; 4]>>, data: &str) -> Self {
527        Self::new(OpCode::Text, mask, data.as_bytes())
528    }
529
530    /// helper function to construct a binary frame
531    #[inline]
532    pub fn binary_frame(mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
533        Self::new(OpCode::Binary, mask, data)
534    }
535
536    /// helper function to construct a ping frame
537    #[inline]
538    pub fn ping_frame(mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
539        assert!(data.len() <= 125);
540        Self::new(OpCode::Ping, mask, data)
541    }
542
543    /// helper function to construct a pong frame
544    #[inline]
545    pub fn pong_frame(mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
546        assert!(data.len() <= 125);
547        Self::new(OpCode::Pong, mask, data)
548    }
549
550    /// helper function to construct a close frame
551    #[inline]
552    pub fn close_frame(
553        mask: impl Into<Option<[u8; 4]>>,
554        code: impl Into<Option<u16>>,
555        data: &[u8],
556    ) -> Self {
557        assert!(data.len() <= 123);
558        let code = code.into();
559        assert!(code.is_some() || data.is_empty());
560        let mut payload = BytesMut::with_capacity(2 + data.len());
561        if let Some(code) = code {
562            payload.put_u16(code);
563            payload.extend_from_slice(data);
564        }
565        Self::new(OpCode::Close, mask, &payload)
566    }
567
568    /// unmask frame if masked
569    #[inline]
570    pub fn unmask(&mut self) -> Option<[u8; 4]> {
571        if let Some(mask) = self.header.masking_key() {
572            apply_mask(&mut self.payload, mask);
573            self.header.set_mask(false);
574            self.header.0.truncate(self.header.0.len() - 4);
575            Some(mask)
576        } else {
577            None
578        }
579    }
580
581    /// mask frame with provide mask key
582    ///
583    /// this will override old mask
584    pub fn mask(&mut self, mask: [u8; 4]) {
585        self.unmask();
586        self.header.set_mask(true);
587        self.header.0.extend_from_slice(&mask);
588        apply_mask(&mut self.payload, mask);
589    }
590
591    /// extend frame payload
592    ///
593    /// **NOTE** this function will unmask first, and then extend payload, mask with old
594    /// mask key finally
595    pub fn extend_from_slice(&mut self, data: &[u8]) {
596        if let Some(mask) = self.unmask() {
597            self.payload.extend_from_slice(data);
598            self.header.set_payload_len(self.payload.len() as u64);
599            self.mask(mask);
600        } else {
601            self.payload.extend_from_slice(data);
602            self.header.set_payload_len(self.payload.len() as u64);
603        }
604    }
605
606    /// get frame header
607    #[inline]
608    pub fn header(&self) -> &Header {
609        &self.header
610    }
611
612    /// get mutable frame header
613    #[inline]
614    pub fn header_mut(&mut self) -> &mut Header {
615        &mut self.header
616    }
617
618    /// get payload
619    #[inline]
620    pub fn payload(&self) -> &BytesMut {
621        &self.payload
622    }
623
624    /// consume frame return header and payload
625    #[inline]
626    pub fn parts(self) -> (Header, BytesMut) {
627        (self.header, self.payload)
628    }
629}