1use std::convert::TryFrom;
2use std::mem;
3use std::usize;
4
5use byteorder::{BigEndian, ByteOrder, NativeEndian};
6use bytes::BytesMut;
7use tokio_util::codec::{Decoder, Encoder};
8
9use crate::mask::Mask;
10use crate::{Error, Result};
11
/// Payload length of a frame, tagged with the width of the field that carries it.
///
/// The variant records *how* the length is written in the header, not only its
/// value: `FrameHeader::header_len` adds 0, 2 or 8 bytes for `Small`, `Medium`
/// and `Large` respectively.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DataLength {
    /// Length stored directly in the low 7 bits of the second header byte.
    Small(u8),
    /// Length stored in a 16-bit big-endian field after the marker value 126.
    Medium(u16),
    /// Length stored in a 64-bit big-endian field after the marker value 127.
    Large(u64),
}
22
23impl From<u64> for DataLength {
24 fn from(n: u64) -> Self {
25 if n <= 125 {
26 Self::Small(n as u8)
27 } else if n <= 65535 {
28 Self::Medium(n as u16)
29 } else {
30 Self::Large(n)
31 }
32 }
33}
34
35impl TryFrom<DataLength> for u64 {
36 type Error = Error;
37
38 fn try_from(len: DataLength) -> Result<Self> {
39 match len {
40 DataLength::Small(n) => Ok(n as u64),
41 DataLength::Medium(n) => {
42 if n <= 125 {
43 return Err(format!("payload length {} should not be represented using 16 bits", n).into());
44 }
45
46 Ok(n as u64)
47 }
48 DataLength::Large(n) => {
49 if n <= 65535 {
50 return Err(format!("payload length {} should not be represented using 64 bits", n).into());
51 }
52
53 if n >= 0x8000_0000_0000_0000 {
54 return Err(format!("frame is too long: {} bytes ({:x})", n, n).into());
55 }
56
57 Ok(n as u64)
58 }
59 }
60 }
61}
62
63impl From<usize> for DataLength {
64 fn from(n: usize) -> Self {
65 Self::from(n as u64)
66 }
67}
68
69impl TryFrom<DataLength> for usize {
70 type Error = Error;
71
72 fn try_from(len: DataLength) -> Result<Self> {
73 let len = u64::try_from(len)?;
74 if len > usize::MAX as u64 {
75 return Err(format!(
76 "frame of {} bytes can't be parsed on a {}-bit platform",
77 len,
78 mem::size_of::<usize>() / 8
79 )
80 .into());
81 }
82
83 Ok(len as usize)
84 }
85}
86
87#[derive(Clone, Debug, PartialEq)]
92pub struct FrameHeader {
93 pub(crate) fin: bool,
94 pub(crate) rsv: u8,
95 pub(crate) opcode: u8,
96 pub(crate) mask: Option<Mask>,
97 pub(crate) data_len: DataLength,
98}
99
100impl FrameHeader {
101 pub fn new(fin: bool, rsv: u8, opcode: u8, mask: Option<Mask>, data_len: DataLength) -> Self {
103 Self {
104 fin,
105 rsv,
106 opcode,
107 mask,
108 data_len,
109 }
110 }
111
112 pub fn fin(&self) -> bool {
114 self.fin
115 }
116
117 pub fn rsv(&self) -> u8 {
121 self.rsv
122 }
123
124 pub fn opcode(&self) -> u8 {
126 self.opcode
127 }
128
129 pub fn mask(&self) -> Option<Mask> {
131 self.mask
132 }
133
134 pub fn data_len(&self) -> DataLength {
136 self.data_len
137 }
138
139 pub fn header_len(&self) -> usize {
144 let mut len = 1 + 1 ;
145 len += match self.data_len {
146 DataLength::Small(_) => 0,
147 DataLength::Medium(_) => 2,
148 DataLength::Large(_) => 8,
149 };
150
151 if self.mask.is_some() {
152 len += 4;
153 }
154
155 len
156 }
157
158 pub(crate) fn parse_slice(buf: &[u8]) -> Option<(Self, usize)> {
159 if buf.len() < 2 {
160 return None;
161 }
162
163 let fin_opcode = buf[0];
164 let mask_data_len = buf[1];
165 let mut header_len = 2;
166 let fin = (fin_opcode & 0x80) != 0;
167 let rsv = (fin_opcode & 0xf0) & !0x80;
168 let opcode = fin_opcode & 0x0f;
169
170 let (buf, data_len) = match mask_data_len & 0x7f {
171 127 => {
172 if buf.len() < 10 {
173 return None;
174 }
175
176 header_len += 8;
177
178 (&buf[10..], DataLength::Large(BigEndian::read_u64(&buf[2..10])))
179 }
180 126 => {
181 if buf.len() < 4 {
182 return None;
183 }
184
185 header_len += 2;
186
187 (&buf[4..], DataLength::Medium(BigEndian::read_u16(&buf[2..4])))
188 }
189 n => {
190 assert!(n < 126);
191 (&buf[2..], DataLength::Small(n))
192 }
193 };
194
195 let mask = if mask_data_len & 0x80 == 0 {
196 None
197 } else {
198 if buf.len() < 4 {
199 return None;
200 }
201
202 header_len += 4;
203 Some(NativeEndian::read_u32(buf).into())
204 };
205
206 let header = Self {
207 fin,
208 rsv,
209 opcode,
210 mask,
211 data_len,
212 };
213
214 debug_assert_eq!(header.header_len(), header_len);
215 Some((header, header_len))
216 }
217
218 pub(crate) fn write_to_slice(&self, dst: &mut [u8]) {
219 let FrameHeader {
220 fin,
221 rsv,
222 opcode,
223 mask,
224 data_len,
225 } = *self;
226
227 let mut fin_opcode = rsv | opcode;
228 if fin {
229 fin_opcode |= 0x80
230 };
231
232 dst[0] = fin_opcode;
233
234 let mask_bit = if mask.is_some() { 0x80 } else { 0 };
235
236 let dst = match data_len {
237 DataLength::Small(n) => {
238 dst[1] = mask_bit | n;
239 &mut dst[2..]
240 }
241 DataLength::Medium(n) => {
242 let (dst, rest) = dst.split_at_mut(4);
243 dst[1] = mask_bit | 126;
244 BigEndian::write_u16(&mut dst[2..4], n);
245 rest
246 }
247 DataLength::Large(n) => {
248 let (dst, rest) = dst.split_at_mut(10);
249 dst[1] = mask_bit | 127;
250 BigEndian::write_u64(&mut dst[2..10], n);
251 rest
252 }
253 };
254
255 if let Some(mask) = mask {
256 NativeEndian::write_u32(dst, mask.into());
257 }
258 }
259
260 pub(crate) fn write_to_bytes(&self, dst: &mut BytesMut) {
261 let data_len = match self.data_len {
262 DataLength::Small(n) => n as usize,
263 DataLength::Medium(n) => n as usize,
264 DataLength::Large(n) => n as usize,
265 };
266
267 let initial_len = dst.len();
268 let header_len = self.header_len();
269 dst.reserve(header_len + data_len);
270
271 unsafe {
272 dst.set_len(initial_len + header_len);
273 }
274
275 let dst_slice = &mut dst[initial_len..(initial_len + header_len)];
276 self.write_to_slice(dst_slice);
277 }
278}
279
/// Stateless codec that reads and writes [`FrameHeader`]s on a `BytesMut`
/// through the `Decoder` / `Encoder` traits.
pub struct FrameHeaderCodec;
286
287impl Decoder for FrameHeaderCodec {
288 type Item = FrameHeader;
289 type Error = Error;
290
291 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrameHeader>> {
292 use bytes::Buf;
293
294 Ok(FrameHeader::parse_slice(src.chunk()).map(|(header, header_len)| {
295 src.advance(header_len);
296 header
297 }))
298 }
299}
300
301impl Encoder<FrameHeader> for FrameHeaderCodec {
302 type Error = Error;
303
304 fn encode(&mut self, item: FrameHeader, dst: &mut BytesMut) -> Result<()> {
305 self.encode(&item, dst)
306 }
307}
308
309impl<'a> Encoder<&'a FrameHeader> for FrameHeaderCodec {
310 type Error = Error;
311
312 fn encode(&mut self, item: &'a FrameHeader, dst: &mut BytesMut) -> Result<()> {
313 item.write_to_bytes(dst);
314 Ok(())
315 }
316}
317
#[cfg(test)]
mod tests {
    use assert_allocations::assert_allocated_bytes;
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    use crate::frame::{FrameHeader, FrameHeaderCodec};

    /// Property: encoding a header and decoding it again yields an equal
    /// header, writes exactly `header_len()` bytes, and consumes all of them.
    ///
    /// The `assert_allocated_bytes` wrappers pin the allocation budget of each
    /// step (as the helper's name suggests; its exact semantics live in the
    /// `assert_allocations` crate).
    #[quickcheck]
    fn round_trips(fin: bool, is_text: bool, mask: Option<u32>, data_len: u16) {
        // Building the header itself must not allocate.
        let original = assert_allocated_bytes(0, || FrameHeader {
            fin,
            rsv: 0,
            opcode: if is_text { 1 } else { 2 },
            mask: mask.map(|key| key.into()),
            data_len: u64::from(data_len).into(),
        });

        // Encoding reserves header + payload; 8 is the floor of that budget.
        let budget = (original.header_len() + data_len as usize).max(8);

        assert_allocated_bytes(budget, || {
            let mut codec = FrameHeaderCodec;
            let mut buf = BytesMut::new();
            codec.encode(&original, &mut buf).unwrap();
            let expected_len = original.header_len();
            assert_eq!(buf.len(), expected_len);

            let decoded = codec.decode(&mut buf).unwrap().unwrap();
            assert_eq!(decoded.header_len(), expected_len);
            assert_eq!(buf.len(), 0);
            assert_eq!(original, decoded)
        })
    }
}