Skip to main content

wtx/web_socket/
misc.rs

1use crate::{
2  misc::{ConnectionState, FnMutFut, from_utf8_basic},
3  rng::Rng,
4  stream::StreamWriter,
5  web_socket::{
6    CloseCode, Frame, MASK_MASK, MAX_CONTROL_PAYLOAD_LEN, MAX_HEADER_LEN, OP_CODE_MASK, OpCode,
7    WebSocketError, web_socket_writer::manage_normal_frame,
8  },
9};
10
11/// Copies `frame_code` and `frame_payload` into `buffer`.
12#[inline]
13pub fn fill_buffer_with_close_frame(
14  buffer: &mut [u8],
15  frame_code: CloseCode,
16  frame_payload: &[u8],
17) -> crate::Result<()> {
18  let rest = fill_buffer_with_close_code(buffer, frame_code);
19  let Some(slice) = rest.and_then(|el| el.get_mut(..frame_payload.len())) else {
20    return Err(WebSocketError::InvalidCloseFrameParams.into());
21  };
22  slice.copy_from_slice(frame_payload);
23  Ok(())
24}
25
26/// The first two bytes of `buffer` are filled with `code`. Does nothing if `buffer` is
27/// less than 2 bytes.
28#[inline]
29pub fn fill_buffer_with_close_code(buffer: &mut [u8], code: CloseCode) -> Option<&mut [u8]> {
30  let [a, b, rest @ ..] = buffer else {
31    return None;
32  };
33  let [c, d] = u16::from(code).to_be_bytes();
34  *a = c;
35  *b = d;
36  Some(rest)
37}
38
39/// Returns `true` if `payload` is greater than the maximum allowed length.
40#[inline]
41pub(crate) fn check_read_close_frame(
42  connection_state: &mut ConnectionState,
43  payload: &[u8],
44) -> crate::Result<bool> {
45  if connection_state.is_closed() {
46    return Err(crate::Error::ClosedWebSocketConnection);
47  }
48  *connection_state = ConnectionState::Closed;
49  match payload {
50    [] => Ok(false),
51    [_] => Err(WebSocketError::InvalidCloseFrame.into()),
52    [a, b, rest @ ..] => {
53      let _str_validation = from_utf8_basic(rest)?;
54      let close_code = CloseCode::try_from(u16::from_be_bytes([*a, *b]))?;
55      if !close_code.is_allowed() || rest.len() > MAX_CONTROL_PAYLOAD_LEN - 2 {
56        Ok(true)
57      } else {
58        Ok(false)
59      }
60    }
61  }
62}
63
64pub(crate) fn control_frame_payload(data: &[u8]) -> ([u8; MAX_CONTROL_PAYLOAD_LEN], u8) {
65  let len = data.len().min(MAX_CONTROL_PAYLOAD_LEN);
66  let mut array = [0; MAX_CONTROL_PAYLOAD_LEN];
67  let slice = array.get_mut(..len).unwrap_or_default();
68  slice.copy_from_slice(data.get(..len).unwrap_or_default());
69  (array, len.try_into().unwrap_or_default())
70}
71
72pub(crate) fn fill_header_from_params<const IS_CLIENT: bool>(
73  fin: bool,
74  header: &mut [u8; MAX_HEADER_LEN],
75  op_code: OpCode,
76  payload_len: usize,
77  rsv1: u8,
78) -> u8 {
79  fn first_header_byte(fin: bool, op_code: OpCode, rsv1: u8) -> u8 {
80    (u8::from(fin) << 7) | rsv1 | u8::from(op_code)
81  }
82
83  match payload_len {
84    0..=125 => {
85      let [a, b, ..] = header;
86      *a = first_header_byte(fin, op_code, rsv1);
87      *b = u8::try_from(payload_len).unwrap_or_default();
88      2
89    }
90    126..=65535 => {
91      let [len_c, len_d] = u16::try_from(payload_len).map(u16::to_be_bytes).unwrap_or_default();
92      let [a, b, c, d, ..] = header;
93      *a = first_header_byte(fin, op_code, rsv1);
94      *b = 126;
95      *c = len_c;
96      *d = len_d;
97      4
98    }
99    _ => {
100      let len = u64::try_from(payload_len).map(u64::to_be_bytes).unwrap_or_default();
101      let [len_c, len_d, len_e, len_f, len_g, len_h, len_i, len_j] = len;
102      let [a, b, c, d, e, f, g, h, i, j, ..] = header;
103      *a = first_header_byte(fin, op_code, rsv1);
104      *b = 127;
105      *c = len_c;
106      *d = len_d;
107      *e = len_e;
108      *f = len_f;
109      *g = len_g;
110      *h = len_h;
111      *i = len_i;
112      *j = len_j;
113      10
114    }
115  }
116}
117
118pub(crate) const fn has_masked_frame(second_header_byte: u8) -> bool {
119  second_header_byte & MASK_MASK != 0
120}
121
122pub(crate) fn op_code(first_header_byte: u8) -> crate::Result<OpCode> {
123  OpCode::try_from(first_header_byte & OP_CODE_MASK)
124}
125
126pub(crate) async fn write_control_frame<A, RNG, const IS_CLIENT: bool>(
127  aux: A,
128  connection_state: &mut ConnectionState,
129  no_masking: bool,
130  op_code: OpCode,
131  payload: &mut [u8],
132  rng: &mut RNG,
133  mut wsc_cb: impl for<'any> FnMutFut<(A, &'any [u8], &'any [u8]), Result = crate::Result<()>>,
134) -> crate::Result<()>
135where
136  RNG: Rng,
137{
138  let mut frame = Frame::<_, IS_CLIENT>::new_fin(op_code, payload);
139  manage_normal_frame(connection_state, &mut frame, no_masking, rng);
140  wsc_cb.call((aux, frame.header(), frame.payload())).await?;
141  Ok(())
142}
143
144pub(crate) async fn write_control_frame_cb<SW>(
145  stream: &mut SW,
146  header: &[u8],
147  payload: &[u8],
148) -> crate::Result<()>
149where
150  SW: StreamWriter,
151{
152  stream.write_all_vectored(&[header, payload]).await?;
153  Ok(())
154}