ws_tool/codec/frame/
mod.rs

1use crate::errors::{ProtocolError, WsError};
2use crate::frame::{get_bit, HeaderView, OpCode, SimplifiedHeader};
3use http;
4use crate::protocol::{cal_accept_key, standard_handshake_req_check};
5use bytes::BytesMut;
6use std::fmt::Debug;
7use std::ops::Range;
8
9#[cfg(feature = "sync")]
10mod blocking;
11
12#[cfg(feature = "sync")]
13pub use blocking::*;
14
15#[cfg(feature = "async")]
16mod non_blocking;
17
18#[cfg(feature = "async")]
19pub use non_blocking::*;
20
/// text frame utf-8 checking policy
#[derive(Debug, Clone)]
pub enum ValidateUtf8Policy {
    /// do not validate utf-8 at all
    Off,
    /// validate each text fragment as it arrives and fail early on invalid utf-8
    FastFail,
    /// validate utf-8 only after fragments have been merged
    On,
}

impl ValidateUtf8Policy {
    /// Whether text payloads should be validated at all.
    ///
    /// True for every policy except [`ValidateUtf8Policy::Off`].
    pub fn should_check(&self) -> bool {
        !matches!(self, Self::Off)
    }

    /// Whether individual fragments should be validated as they arrive.
    ///
    /// True only for [`ValidateUtf8Policy::FastFail`].
    pub fn is_fast_fail(&self) -> bool {
        matches!(self, Self::FastFail)
    }
}
42
43/// frame send/recv config
44#[derive(Debug, Clone)]
45pub struct FrameConfig {
46    /// check rsv1 bits
47    pub check_rsv: bool,
48    /// auto mask send frame payload, for client, it must be true
49    pub mask_send_frame: bool,
50    /// allocate new buf for every frame
51    pub renew_buf_on_write: bool,
52    /// auto unmask a masked frame payload
53    pub auto_unmask: bool,
54    /// limit max payload size
55    pub max_frame_payload_size: usize,
56    /// auto split size, if set 0, do not split frame
57    pub auto_fragment_size: usize,
58    /// auto merge fragmented frames into one frame
59    pub merge_frame: bool,
60    /// utf8 check policy
61    pub validate_utf8: ValidateUtf8Policy,
62    /// resize size of read buf, default 4K
63    pub resize_size: usize,
64    /// if available len < resize, resize read buf, default 1K
65    pub resize_thresh: usize,
66}
67
68impl Default for FrameConfig {
69    fn default() -> Self {
70        Self {
71            check_rsv: true,
72            mask_send_frame: true,
73            renew_buf_on_write: false,
74            auto_unmask: true,
75            max_frame_payload_size: 0,
76            auto_fragment_size: 0,
77            merge_frame: true,
78            validate_utf8: ValidateUtf8Policy::FastFail,
79            resize_size: 4096,
80            resize_thresh: 1024,
81        }
82    }
83}
84
/// apply websocket mask to buf by given key
///
/// Byte `i` of `buf` is XORed with `mask[i % 4]`. Applying the same mask a
/// second time restores the original bytes, so this both masks and unmasks.
#[inline]
pub fn apply_mask(buf: &mut [u8], mask: [u8; 4]) {
    apply_mask_array_chunk(buf, mask)
}

/// Word-at-a-time implementation of [`apply_mask`].
///
/// Whole 4-byte chunks are XORed as a single `u32`, and the fewer than 4
/// trailing bytes are XORed one at a time. Every chunk (and the tail) starts at
/// an offset that is a multiple of 4, so the mask phase is always `mask[0]`.
#[inline]
fn apply_mask_array_chunk(buf: &mut [u8], mask: [u8; 4]) {
    let mask32 = u32::from_ne_bytes(mask);
    let mut chunks = buf.chunks_exact_mut(4);
    for chunk in &mut chunks {
        // A chunk of an arbitrary byte slice is not guaranteed to be 4-byte
        // aligned, so the word is loaded and stored by value rather than via a
        // `&mut u32` (a misaligned reference is undefined behavior).
        let word = u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) ^ mask32;
        chunk.copy_from_slice(&word.to_ne_bytes());
    }
    for (i, byte) in chunks.into_remainder().iter_mut().enumerate() {
        *byte ^= mask[i & 3];
    }
}
103
104/// websocket frame reader
105pub struct FrameReadState {
106    fragmented: bool,
107    config: FrameConfig,
108    fragmented_data: Vec<u8>,
109    fragmented_type: OpCode,
110    buf: FrameBuffer,
111}
112
113impl Default for FrameReadState {
114    fn default() -> Self {
115        Self {
116            fragmented: false,
117            config: Default::default(),
118            fragmented_data: vec![],
119            fragmented_type: OpCode::default(),
120            buf: FrameBuffer::new(),
121        }
122    }
123}
124
125impl FrameReadState {
126    /// construct with config
127    pub fn with_config(config: FrameConfig) -> Self {
128        Self {
129            config,
130            ..Self::default()
131        }
132    }
133
134    /// check if data in buffer is enough to parse frame header
135    pub fn is_header_ok(&self) -> bool {
136        let ava_data = self.buf.ava_data();
137        if ava_data.len() < 2 {
138            false
139        } else {
140            let len = ava_data[1] & 0b01111111;
141            let mask = get_bit(&ava_data, 1, 0);
142            let mut min_len = match len {
143                0..=125 => 2,
144                126 => 4,
145                127 => 10,
146                _ => unreachable!(),
147            };
148            if mask {
149                min_len += 4;
150            }
151            ava_data.len() >= min_len
152        }
153    }
154
155    /// return current frame header bits of buffer
156    #[inline]
157    pub fn get_leading_bits(&self) -> u8 {
158        self.buf.ava_data()[0] >> 4
159    }
160
161    /// try to parse frame header in buffer, return (header_len, payload_len, header_len + payload_len)
162    #[inline]
163    pub fn parse_frame_header(&mut self) -> Result<(usize, usize, usize), WsError> {
164        let ava_data = self.buf.ava_data();
165        let leading_bits = self.get_leading_bits();
166        let max_payload_size = self.config.max_frame_payload_size;
167        let check_rsv = self.config.check_rsv;
168
169        fn parse_payload_len(source: &[u8]) -> Result<(usize, usize), ProtocolError> {
170            match source[1] {
171                len @ (0..=125 | 128..=253) => Ok((1, (len & 127) as usize)),
172                126 | 254 => {
173                    if source.len() < 4 {
174                        return Err(ProtocolError::InsufficientLen(source.len()));
175                    }
176                    Ok((
177                        1 + 2,
178                        u16::from_be_bytes((&source[2..4]).try_into().unwrap()) as usize,
179                    ))
180                }
181                127 | 255 => {
182                    if source.len() < 10 {
183                        return Err(ProtocolError::InsufficientLen(source.len()));
184                    }
185                    Ok((
186                        1 + 8,
187                        usize::from_be_bytes((&source[2..(8 + 2)]).try_into().unwrap()),
188                    ))
189                }
190            }
191        }
192
193        if check_rsv && !(leading_bits == 0b00001000 || leading_bits == 0b00000000) {
194            return Err(WsError::ProtocolError {
195                close_code: 1008,
196                error: ProtocolError::InvalidLeadingBits(leading_bits),
197            });
198        }
199        let (len_occ_bytes, payload_len) =
200            parse_payload_len(ava_data).map_err(|e| WsError::ProtocolError {
201                close_code: 1008,
202                error: e,
203            })?;
204
205        if max_payload_size > 0 && payload_len > max_payload_size {
206            return Err(WsError::ProtocolError {
207                close_code: 1008,
208                error: ProtocolError::PayloadTooLarge(max_payload_size),
209            });
210        }
211        let mask = get_bit(ava_data, 1, 0);
212        let header_len = 1 + len_occ_bytes + if mask { 4 } else { 0 };
213        Ok((header_len, payload_len, header_len + payload_len))
214    }
215
216    /// get a frame and reset state
217    #[inline]
218    pub fn consume_frame(
219        &mut self,
220        header_len: usize,
221        payload_len: usize,
222        total_len: usize,
223    ) -> (SimplifiedHeader, Range<usize>) {
224        let buf = &mut self.buf;
225        let auto_unmask = self.config.auto_unmask;
226
227        let ava_data = buf.ava_mut_data();
228        let (header_data, remain) = ava_data.split_at_mut(header_len);
229        let header = HeaderView(header_data);
230        let payload = remain.split_at_mut(payload_len).0;
231        if auto_unmask {
232            if let Some(mask) = header.masking_key() {
233                apply_mask(payload, mask)
234            }
235        }
236        let header: SimplifiedHeader = header.into();
237        let s_idx = buf.consume_idx + header_len;
238        let e_idx = s_idx + payload_len;
239        buf.consume(total_len);
240        (header, s_idx..e_idx)
241    }
242
243    fn check_frame(
244        &mut self,
245        header: SimplifiedHeader,
246        range: Range<usize>,
247    ) -> Result<(), WsError> {
248        let fragmented = &mut self.fragmented;
249        let utf8_policy = &self.config.validate_utf8;
250        let payload = &self.buf.buf[range];
251        match header.code {
252            OpCode::Continue => {
253                if !*fragmented {
254                    return Err(WsError::ProtocolError {
255                        close_code: 1002,
256                        error: ProtocolError::MissInitialFragmentedFrame,
257                    });
258                }
259                if header.fin {
260                    *fragmented = false;
261                }
262                Ok(())
263            }
264            OpCode::Binary => {
265                if *fragmented {
266                    return Err(WsError::ProtocolError {
267                        close_code: 1002,
268                        error: ProtocolError::NotContinueFrameAfterFragmented,
269                    });
270                }
271                *fragmented = !header.fin;
272                Ok(())
273            }
274            OpCode::Text => {
275                if *fragmented {
276                    return Err(WsError::ProtocolError {
277                        close_code: 1002,
278                        error: ProtocolError::NotContinueFrameAfterFragmented,
279                    });
280                }
281                if !header.fin {
282                    *fragmented = true;
283                    if header.code == OpCode::Text
284                        && utf8_policy.is_fast_fail()
285                        && simdutf8::basic::from_utf8(payload).is_err()
286                    {
287                        return Err(WsError::ProtocolError {
288                            close_code: 1007,
289                            error: ProtocolError::InvalidUtf8,
290                        });
291                    }
292
293                    Ok(())
294                } else {
295                    if header.code == OpCode::Text
296                        && utf8_policy.should_check()
297                        && simdutf8::basic::from_utf8(payload).is_err()
298                    {
299                        return Err(WsError::ProtocolError {
300                            close_code: 1007,
301                            error: ProtocolError::InvalidUtf8,
302                        });
303                    }
304                    Ok(())
305                }
306            }
307            OpCode::Close | OpCode::Ping | OpCode::Pong => {
308                if !header.fin {
309                    return Err(WsError::ProtocolError {
310                        close_code: 1002,
311                        error: ProtocolError::FragmentedControlFrame,
312                    });
313                }
314                let payload_len = payload.len();
315                if payload.len() > 125 {
316                    let error = ProtocolError::ControlFrameTooBig(payload_len);
317                    return Err(WsError::ProtocolError {
318                        close_code: 1002,
319                        error,
320                    });
321                }
322                if header.code == OpCode::Close {
323                    if payload_len == 1 {
324                        let error = ProtocolError::InvalidCloseFramePayload;
325                        return Err(WsError::ProtocolError {
326                            close_code: 1002,
327                            error,
328                        });
329                    }
330                    if payload_len >= 2 {
331                        // check close code
332                        let mut code_byte = [0u8; 2];
333                        code_byte.copy_from_slice(&payload[..2]);
334                        let code = u16::from_be_bytes(code_byte);
335                        if code < 1000
336                            || (1004..=1006).contains(&code)
337                            || (1015..=2999).contains(&code)
338                            || code >= 5000
339                        {
340                            let error = ProtocolError::InvalidCloseCode(code);
341                            return Err(WsError::ProtocolError {
342                                close_code: 1002,
343                                error,
344                            });
345                        }
346
347                        // utf-8 validation
348                        if String::from_utf8(payload[2..].to_vec()).is_err() {
349                            let error = ProtocolError::InvalidUtf8;
350                            return Err(WsError::ProtocolError {
351                                close_code: 1007,
352                                error,
353                            });
354                        }
355                    }
356                }
357                Ok(())
358            }
359            _ => Err(WsError::UnsupportedFrame(header.code)),
360        }
361    }
362
363    /// This method is technically private, but custom parsers are allowed to use it.
364    #[doc(hidden)]
365    #[inline]
366    pub fn merge_frame(
367        &mut self,
368        header: SimplifiedHeader,
369        range: Range<usize>,
370    ) -> Result<Option<bool>, WsError> {
371        let fragmented = &mut self.fragmented;
372        let fragmented_data = &mut self.fragmented_data;
373        let fragmented_type = &mut self.fragmented_type;
374        let payload = &self.buf.buf[range];
375        match header.code {
376            OpCode::Continue => {
377                fragmented_data.extend_from_slice(payload);
378                if header.fin {
379                    *fragmented = false;
380                    Ok(Some(true))
381                } else {
382                    Ok(None)
383                }
384            }
385            OpCode::Text | OpCode::Binary => {
386                *fragmented_type = header.code;
387                if !header.fin {
388                    *fragmented = true;
389                    *fragmented_type = header.code;
390                    fragmented_data.clear();
391                    fragmented_data.extend_from_slice(payload);
392                    Ok(None)
393                } else {
394                    Ok(Some(false))
395                }
396            }
397            OpCode::Close | OpCode::Ping | OpCode::Pong => Ok(Some(false)),
398            _ => unreachable!(),
399        }
400    }
401}
402
/// Read buffer holding received bytes that have not been parsed yet.
///
/// `buf[consume_idx..produce_idx]` is the unconsumed data. New bytes are
/// written into the slice returned by [`FrameBuffer::prepare`] and committed
/// with [`FrameBuffer::produce`]; parsed bytes are released with
/// [`FrameBuffer::consume`].
pub(crate) struct FrameBuffer {
    pub(crate) buf: Vec<u8>,
    produce_idx: usize,
    consume_idx: usize,
}

impl FrameBuffer {
    /// Create an empty buffer backed by 8 KiB of zeroed storage.
    pub(crate) fn new() -> Self {
        Self {
            buf: vec![0; 8192],
            produce_idx: 0,
            consume_idx: 0,
        }
    }

    /// Return a writable slice of exactly `payload_size` bytes right after the
    /// unconsumed data.
    ///
    /// If the tail lacks room, the unconsumed bytes are moved to the front of
    /// `buf` (invalidating previously returned index ranges) and `buf` grows
    /// when it still cannot hold them plus `payload_size` bytes. Written bytes
    /// become readable only after [`FrameBuffer::produce`].
    pub(crate) fn prepare(&mut self, payload_size: usize) -> &mut [u8] {
        if self.buf.len() - self.produce_idx < payload_size {
            let pending = self.produce_idx - self.consume_idx;
            // `copy_within` is a memmove: overlapping ranges are handled, so no
            // scratch buffer or second copy is needed.
            self.buf.copy_within(self.consume_idx..self.produce_idx, 0);
            self.consume_idx = 0;
            self.produce_idx = pending;
            let required = pending + payload_size;
            if required > self.buf.len() {
                self.buf.resize(required, 0);
            }
        }
        &mut self.buf[self.produce_idx..self.produce_idx + payload_size]
    }

    /// Unconsumed bytes.
    pub(crate) fn ava_data(&self) -> &[u8] {
        &self.buf[self.consume_idx..self.produce_idx]
    }

    /// Unconsumed bytes, mutably (used to unmask payloads in place).
    pub(crate) fn ava_mut_data(&mut self) -> &mut [u8] {
        &mut self.buf[self.consume_idx..self.produce_idx]
    }

    /// Commit `num` bytes written into the slice returned by `prepare`.
    pub(crate) fn produce(&mut self, num: usize) {
        self.produce_idx += num;
    }

    /// Release `num` bytes from the front of the unconsumed data.
    pub(crate) fn consume(&mut self, num: usize) {
        self.consume_idx += num;
    }
}
463
464/// websocket writing state
465#[allow(dead_code)]
466#[derive(Debug, Clone, Default)]
467pub struct FrameWriteState {
468    config: FrameConfig,
469    header_buf: [u8; 14],
470    buf: BytesMut,
471}
472
473impl FrameWriteState {
474    /// construct with config
475    pub fn with_config(config: FrameConfig) -> Self {
476        Self {
477            config,
478            header_buf: [0; 14],
479            buf: BytesMut::new(),
480        }
481    }
482}
483
484/// do standard handshake check and return response
485pub fn default_handshake_handler(
486    req: http::Request<()>,
487) -> Result<(http::Request<()>, http::Response<String>), (http::Response<String>, WsError)> {
488    match standard_handshake_req_check(&req) {
489        Ok(_) => {
490            let key = req.headers().get("sec-websocket-key").unwrap();
491            let resp = http::Response::builder()
492                .version(http::Version::HTTP_11)
493                .status(http::StatusCode::SWITCHING_PROTOCOLS)
494                .header("Upgrade", "WebSocket")
495                .header("Connection", "Upgrade")
496                .header("Sec-WebSocket-Accept", cal_accept_key(key.as_bytes()))
497                .body(String::new())
498                .unwrap();
499            Ok((req, resp))
500        }
501        Err(e) => {
502            let resp = http::Response::builder()
503                .version(http::Version::HTTP_11)
504                .status(http::StatusCode::BAD_REQUEST)
505                .header("Content-Type", "text/html")
506                .body(e.to_string())
507                .unwrap();
508            Err((resp, e))
509        }
510    }
511}