1use crate::codec::apply_mask;
2use bytes::{BufMut, BytesMut};
3use std::fmt::Debug;
4
/// WebSocket frame opcode: the low four bits of the first header byte.
///
/// Every value in `0..=15` has a variant. `RNC3`..=`RNC7` are the opcodes
/// RFC 6455 reserves for future non-control frames, and `RC11`..=`RC15` the
/// ones reserved for future control frames.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
#[repr(u8)]
pub enum OpCode {
    /// `0x0`: continuation of a fragmented message.
    Continue = 0,
    /// `0x1`: text data frame.
    Text = 1,
    /// `0x2`: binary data frame.
    Binary = 2,
    /// `0x3`: reserved (non-control).
    RNC3 = 3,
    /// `0x4`: reserved (non-control).
    RNC4 = 4,
    /// `0x5`: reserved (non-control).
    RNC5 = 5,
    /// `0x6`: reserved (non-control).
    RNC6 = 6,
    /// `0x7`: reserved (non-control).
    RNC7 = 7,
    /// `0x8`: connection close.
    Close = 8,
    /// `0x9`: ping.
    Ping = 9,
    /// `0xA`: pong.
    Pong = 10,
    /// `0xB`: reserved (control).
    RC11 = 11,
    /// `0xC`: reserved (control).
    RC12 = 12,
    /// `0xD`: reserved (control).
    RC13 = 13,
    /// `0xE`: reserved (control).
    RC14 = 14,
    /// `0xF`: reserved (control).
    RC15 = 15,
}

impl Default for OpCode {
    /// Defaults to [`OpCode::Text`].
    fn default() -> Self {
        Self::Text
    }
}

impl OpCode {
    /// Returns the numeric opcode value (`0..=15`).
    pub fn as_u8(&self) -> u8 {
        *self as u8
    }

    /// Returns `true` only for [`OpCode::Close`].
    pub fn is_close(&self) -> bool {
        matches!(self, Self::Close)
    }

    /// Returns `true` for data-carrying opcodes: `Text`, `Binary` and `Continue`.
    pub fn is_data(&self) -> bool {
        matches!(self, Self::Text | Self::Binary | Self::Continue)
    }

    /// Returns `true` for every opcode RFC 6455 leaves reserved:
    /// `0x3..=0x7` (non-control) and `0xB..=0xF` (control).
    pub fn is_reserved(&self) -> bool {
        // The non-control range ends at 7, not 5: `RNC6` and `RNC7` are reserved too.
        matches!(self.as_u8(), 3..=7 | 11..=15)
    }
}
80
81#[inline]
82pub(crate) fn parse_opcode(val: u8) -> OpCode {
83 unsafe { std::mem::transmute(val & 0b00001111) }
84}
85
/// Reads one bit of `source[byte_idx]`.
///
/// `bit_idx` counts from the most significant bit: `0` selects `0x80`, `7`
/// selects `0x01`.
///
/// # Panics
/// Panics if `byte_idx` is out of bounds or `bit_idx > 7`.
#[inline]
pub(crate) fn get_bit(source: &[u8], byte_idx: usize, bit_idx: u8) -> bool {
    let mask = match bit_idx {
        0..=7 => 0b1000_0000u8 >> bit_idx,
        _ => unreachable!(),
    };
    // Checked indexing: this is a safe fn that can be reached with headers built
    // from arbitrary bytes (e.g. `Header::raw`), so an out-of-range index must
    // panic instead of reading out of bounds.
    source[byte_idx] & mask != 0
}
101
/// Sets (`val == true`) or clears (`val == false`) one bit of `source[byte_idx]`.
///
/// `bit_idx` counts from the most significant bit: `0` is `0x80`, `7` is `0x01`.
///
/// # Panics
/// Panics if `bit_idx > 7` or `byte_idx` is out of bounds.
#[inline]
pub(crate) fn set_bit(source: &mut [u8], byte_idx: usize, bit_idx: u8, val: bool) {
    let bit = match bit_idx {
        0..=7 => 0x80u8 >> bit_idx,
        _ => unreachable!(),
    };
    let byte = &mut source[byte_idx];
    *byte = if val { *byte | bit } else { *byte & !bit };
}
132
/// Generates the read-only header accessors shared by `HeaderView` and `Header`.
///
/// The invoking type must be a tuple struct whose `.0` field holds the encoded
/// header bytes and can be indexed and passed as `&[u8]`. `opcode`,
/// `payload_len` and `masking_key` index that buffer with bounds checks, so a
/// truncated header makes them panic.
macro_rules! impl_get {
    () => {
        /// Reads bit `bit_idx` (0 = most significant) of header byte `byte_idx`.
        #[inline]
        fn get_bit(&self, byte_idx: usize, bit_idx: u8) -> bool {
            get_bit(&self.0, byte_idx, bit_idx)
        }

        /// FIN flag: most significant bit of byte 0.
        #[inline]
        pub fn fin(&self) -> bool {
            self.get_bit(0, 0)
        }

        /// RSV1 flag: bit 1 (from the top) of byte 0.
        #[inline]
        pub fn rsv1(&self) -> bool {
            self.get_bit(0, 1)
        }

        /// RSV2 flag: bit 2 (from the top) of byte 0.
        #[inline]
        pub fn rsv2(&self) -> bool {
            self.get_bit(0, 2)
        }

        /// RSV3 flag: bit 3 (from the top) of byte 0.
        #[inline]
        pub fn rsv3(&self) -> bool {
            self.get_bit(0, 3)
        }

        /// Opcode stored in the low four bits of byte 0.
        ///
        /// # Panics
        /// Panics if the header buffer is empty.
        #[inline]
        pub fn opcode(&self) -> OpCode {
            // Checked indexing: the buffer may come from arbitrary bytes, so an
            // empty one must panic rather than be read out of bounds.
            parse_opcode(self.0[0])
        }

        /// MASK flag: most significant bit of byte 1.
        #[inline]
        pub fn masked(&self) -> bool {
            self.get_bit(1, 0)
        }

        /// Bytes used by the length encoding, counting byte 1 itself:
        /// 1 for a 7-bit length, 3 for a 16-bit extended length (marker 126),
        /// 9 for a 64-bit extended length (marker 127).
        #[inline]
        fn len_bytes(&self) -> usize {
            // Drop the MASK bit; the remaining 7 bits are the length or marker.
            match self.0[1] & 0x7F {
                126 => 3,
                127 => 9,
                _ => 1,
            }
        }

        /// Payload length, decoded from byte 1 and any big-endian extended length.
        ///
        /// # Panics
        /// Panics if the buffer is too short for the encoding announced in byte 1.
        #[inline]
        pub fn payload_len(&self) -> u64 {
            let header = &self.0;
            // Byte 1 (index 1) must exist, so at least two bytes are required.
            assert!(header.len() >= 2);
            match header[1] & 0x7F {
                126 => {
                    assert!(header.len() >= 4);
                    u16::from_be_bytes([header[2], header[3]]) as u64
                }
                127 => {
                    assert!(header.len() >= 10);
                    let mut be = [0u8; 8];
                    be.copy_from_slice(&header[2..10]);
                    u64::from_be_bytes(be)
                }
                len => len as u64,
            }
        }

        /// The 4-byte masking key that follows the length field, if MASK is set.
        ///
        /// # Panics
        /// Panics if MASK is set but the buffer ends before the key.
        #[inline]
        pub fn masking_key(&self) -> Option<[u8; 4]> {
            if !self.masked() {
                return None;
            }
            let start = 1 + self.len_bytes();
            let mut key = [0u8; 4];
            key.copy_from_slice(&self.0[start..start + 4]);
            Some(key)
        }
    };
}
218
/// Returns the encoded size, in bytes, of a frame header.
///
/// The size is one flags/opcode byte, plus a length field of 1, 3 or 9 bytes
/// depending on `payload_len`, plus 4 bytes of masking key when `mask` is set.
pub fn header_len(mask: bool, payload_len: u64) -> usize {
    let length_field = match payload_len {
        0..=125 => 1,
        126..=65535 => 3,
        _ => 9,
    };
    let key_field = if mask { 4 } else { 0 };
    1 + length_field + key_field
}
234
235#[inline]
236const fn first_byte(fin: bool, rsv1: bool, rsv2: bool, rsv3: bool, opcode: OpCode) -> u8 {
237 let leading = match (fin, rsv1, rsv2, rsv3) {
238 (true, true, true, true) => 0b1111_0000,
239 (true, true, true, false) => 0b1110_0000,
240 (true, true, false, true) => 0b1101_0000,
241 (true, true, false, false) => 0b1100_0000,
242 (true, false, true, true) => 0b1011_0000,
243 (true, false, true, false) => 0b1010_0000,
244 (true, false, false, true) => 0b1001_0000,
245 (true, false, false, false) => 0b1000_0000,
246 (false, true, true, true) => 0b0111_0000,
247 (false, true, true, false) => 0b0110_0000,
248 (false, true, false, true) => 0b0101_0000,
249 (false, true, false, false) => 0b0100_0000,
250 (false, false, true, true) => 0b0011_0000,
251 (false, false, true, false) => 0b0010_0000,
252 (false, false, false, true) => 0b0001_0000,
253 (false, false, false, false) => 0b0000_0000,
254 };
255 leading | opcode as u8
256}
257
258#[allow(clippy::too_many_arguments)]
260pub fn ctor_header<M: Into<Option<[u8; 4]>>>(
261 buf: &mut [u8],
262 fin: bool,
263 rsv1: bool,
264 rsv2: bool,
265 rsv3: bool,
266 mask_key: M,
267 opcode: OpCode,
268 payload_len: u64,
269) -> &[u8] {
270 let mask = mask_key.into();
271 let mut header_len = 1;
272 if mask.is_some() {
273 header_len += 4;
274 }
275 if payload_len <= 125 {
276 buf[1] = payload_len as u8;
277 header_len += 1;
278 } else if payload_len <= 65535 {
279 buf[1] = 126;
280 buf[2..4].copy_from_slice(&(payload_len as u16).to_be_bytes());
281 header_len += 3;
282 } else {
283 buf[1] = 127;
284 buf[2..10].copy_from_slice(&payload_len.to_be_bytes());
285 header_len += 9;
286 }
287 buf[0] = first_byte(fin, rsv1, rsv2, rsv3, opcode);
288 if let Some(key) = mask {
289 set_bit(buf, 1, 0, true);
290 buf[(header_len - 4)..header_len].copy_from_slice(&key);
291 } else {
292 set_bit(buf, 1, 0, false);
293 }
294 &buf[..header_len]
295}
296
/// Cross-checks the two header encoders. `ctor_header` (which writes into a
/// caller buffer) must produce exactly the bytes of `Header::new`, and the
/// result must decode back to the original inputs.
#[test]
fn test_header() {
    fn rand_mask() -> Option<[u8; 4]> {
        fastrand::bool().then(|| fastrand::u32(..).to_be_bytes())
    }

    fn rand_code() -> OpCode {
        // Every 4-bit value is a valid opcode, so the crate's own decoder works
        // and no transmute is needed.
        parse_opcode(fastrand::u8(0..16))
    }

    // Choose the length class first. A uniform u64 is almost always > 65535,
    // which would leave the 7-bit and 16-bit encodings essentially untested.
    fn rand_payload_len() -> u64 {
        match fastrand::u8(0..3) {
            0 => fastrand::u64(0..=125),
            1 => fastrand::u64(126..=65535),
            _ => fastrand::u64(65536..),
        }
    }

    fn check(buf: &mut [u8], payload_len: u64) {
        let fin = fastrand::bool();
        let rsv1 = fastrand::bool();
        let rsv2 = fastrand::bool();
        let rsv3 = fastrand::bool();
        let mask_key = rand_mask();
        let opcode = rand_code();

        let slice = ctor_header(buf, fin, rsv1, rsv2, rsv3, mask_key, opcode, payload_len);
        let header = Header::new(fin, rsv1, rsv2, rsv3, mask_key, opcode, payload_len);
        assert_eq!(slice, header.as_bytes());
        assert_eq!(slice.len(), header_len(mask_key.is_some(), payload_len));

        let view = HeaderView(slice);
        assert_eq!(view.fin(), fin);
        assert_eq!(view.rsv1(), rsv1);
        assert_eq!(view.rsv2(), rsv2);
        assert_eq!(view.rsv3(), rsv3);
        assert_eq!(view.opcode(), opcode);
        assert_eq!(view.masked(), mask_key.is_some());
        assert_eq!(view.masking_key(), mask_key);
        assert_eq!(view.payload_len(), payload_len);
    }

    let mut buf = [0u8; 14];
    // Encoding boundaries, checked deterministically.
    for &edge in &[0, 1, 125, 126, 127, 65535, 65536, u64::MAX] {
        check(&mut buf, edge);
    }
    for _ in 0..1000 {
        check(&mut buf, rand_payload_len());
    }
}
331
332#[derive(Debug, Clone, Copy)]
334pub struct SimplifiedHeader {
335 pub fin: bool,
337 pub rsv1: bool,
339 pub rsv2: bool,
341 pub rsv3: bool,
343 pub code: OpCode,
345}
346
347impl<'a> From<HeaderView<'a>> for SimplifiedHeader {
348 fn from(value: HeaderView<'a>) -> Self {
349 Self {
350 fin: value.fin(),
351 rsv1: value.rsv1(),
352 rsv2: value.rsv2(),
353 rsv3: value.rsv3(),
354 code: value.opcode(),
355 }
356 }
357}
358
359#[derive(Debug, Clone, Copy)]
361pub struct HeaderView<'a>(pub(crate) &'a [u8]);
362
363impl<'a> HeaderView<'a> {
364 impl_get! {}
365}
366
367#[derive(Debug, Clone)]
369pub struct Header(pub(crate) BytesMut);
370
371impl Header {
372 impl_get! {}
373 pub fn as_bytes(&self) -> &[u8] {
375 &self.0
376 }
377
378 #[inline]
379 fn set_bit(&mut self, byte_idx: usize, bit_idx: u8, val: bool) {
380 set_bit(&mut self.0, byte_idx, bit_idx, val)
381 }
382
383 #[inline]
385 pub fn set_fin(&mut self, val: bool) {
386 self.set_bit(0, 0, val)
387 }
388
389 #[inline]
391 pub fn set_rsv1(&mut self, val: bool) {
392 self.set_bit(0, 1, val)
393 }
394
395 #[inline]
397 pub fn set_rsv2(&mut self, val: bool) {
398 self.set_bit(0, 2, val)
399 }
400
401 #[inline]
403 pub fn set_rsv3(&mut self, val: bool) {
404 self.set_bit(0, 3, val)
405 }
406
407 #[inline]
409 pub fn set_opcode(&mut self, code: OpCode) {
410 let header = &mut self.0;
411 let leading_bits = (header[0] >> 4) << 4;
412 header[0] = leading_bits | code.as_u8()
413 }
414
415 #[inline]
418 pub fn set_mask(&mut self, mask: bool) {
419 self.set_bit(1, 0, mask);
420 }
421
422 #[inline]
425 pub fn set_payload_len(&mut self, len: u64) {
426 let mask = self.masking_key();
427 let mask_len = mask.as_ref().map(|_| 4).unwrap_or_default();
428 let header = &mut self.0;
429 let mut leading_byte = header[1];
430 match len {
431 0..=125 => {
432 leading_byte &= 128;
433 header[1] = leading_byte | (len as u8);
434 let idx = 1 + 1;
435 header.resize(idx + mask_len, 0);
436 if let Some(mask) = mask {
437 header[idx..].copy_from_slice(&mask);
438 }
439 }
440 126..=65535 => {
441 leading_byte &= 128;
442 header[1] = leading_byte | 126;
443 let len_arr = (len as u16).to_be_bytes();
444 let idx = 1 + 3;
445 header.resize(idx + mask_len, 0);
446 header[2] = len_arr[0];
447 header[3] = len_arr[1];
448 if let Some(mask) = mask {
449 header[idx..].copy_from_slice(&mask);
450 }
451 }
452 _ => {
453 leading_byte &= 128;
454 header[1] = leading_byte | 127;
455 let len_arr = len.to_be_bytes();
456 let idx = 1 + 9;
457 header.resize(idx + mask_len, 0);
458 header[2..10].copy_from_slice(&len_arr[..8]);
459 if let Some(mask) = mask {
460 header[idx..].copy_from_slice(&mask);
461 }
462 }
463 }
464 }
465
466 pub fn raw(data: BytesMut) -> Self {
468 Self(data)
469 }
470
471 pub fn new<M: Into<Option<[u8; 4]>>>(
473 fin: bool,
474 rsv1: bool,
475 rsv2: bool,
476 rsv3: bool,
477 mask_key: M,
478 opcode: OpCode,
479 payload_len: u64,
480 ) -> Self {
481 let mask = mask_key.into();
482 let len = header_len(mask.is_some(), payload_len);
483 assert!(len >= 2);
484 let mut buf = BytesMut::zeroed(len);
485 buf[0] = first_byte(fin, rsv1, rsv2, rsv3, opcode);
486 let mut header = Self(buf);
487 header.set_mask(mask.is_some());
488 header.set_payload_len(payload_len);
489 if let Some(mask) = mask {
490 header.0[(len - 4)..len].copy_from_slice(&mask);
491 }
492 header
493 }
494}
495
496#[derive(Debug, Clone)]
498pub struct OwnedFrame {
499 pub(crate) header: Header,
500 pub(crate) payload: BytesMut,
501}
502
503impl OwnedFrame {
504 #[inline]
506 pub fn new(code: OpCode, mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
507 let header = Header::new(true, false, false, false, mask, code, data.len() as u64);
508 let mut payload = BytesMut::with_capacity(data.len());
509 payload.extend_from_slice(data);
510 if let Some(mask) = header.masking_key() {
511 apply_mask(&mut payload, mask);
512 }
513 Self { header, payload }
514 }
515
516 #[inline]
520 pub fn with_raw(header: Header, payload: BytesMut) -> Self {
521 Self { header, payload }
522 }
523
524 #[inline]
526 pub fn text_frame(mask: impl Into<Option<[u8; 4]>>, data: &str) -> Self {
527 Self::new(OpCode::Text, mask, data.as_bytes())
528 }
529
530 #[inline]
532 pub fn binary_frame(mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
533 Self::new(OpCode::Binary, mask, data)
534 }
535
536 #[inline]
538 pub fn ping_frame(mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
539 assert!(data.len() <= 125);
540 Self::new(OpCode::Ping, mask, data)
541 }
542
543 #[inline]
545 pub fn pong_frame(mask: impl Into<Option<[u8; 4]>>, data: &[u8]) -> Self {
546 assert!(data.len() <= 125);
547 Self::new(OpCode::Pong, mask, data)
548 }
549
550 #[inline]
552 pub fn close_frame(
553 mask: impl Into<Option<[u8; 4]>>,
554 code: impl Into<Option<u16>>,
555 data: &[u8],
556 ) -> Self {
557 assert!(data.len() <= 123);
558 let code = code.into();
559 assert!(code.is_some() || data.is_empty());
560 let mut payload = BytesMut::with_capacity(2 + data.len());
561 if let Some(code) = code {
562 payload.put_u16(code);
563 payload.extend_from_slice(data);
564 }
565 Self::new(OpCode::Close, mask, &payload)
566 }
567
568 #[inline]
570 pub fn unmask(&mut self) -> Option<[u8; 4]> {
571 if let Some(mask) = self.header.masking_key() {
572 apply_mask(&mut self.payload, mask);
573 self.header.set_mask(false);
574 self.header.0.truncate(self.header.0.len() - 4);
575 Some(mask)
576 } else {
577 None
578 }
579 }
580
581 pub fn mask(&mut self, mask: [u8; 4]) {
585 self.unmask();
586 self.header.set_mask(true);
587 self.header.0.extend_from_slice(&mask);
588 apply_mask(&mut self.payload, mask);
589 }
590
591 pub fn extend_from_slice(&mut self, data: &[u8]) {
596 if let Some(mask) = self.unmask() {
597 self.payload.extend_from_slice(data);
598 self.header.set_payload_len(self.payload.len() as u64);
599 self.mask(mask);
600 } else {
601 self.payload.extend_from_slice(data);
602 self.header.set_payload_len(self.payload.len() as u64);
603 }
604 }
605
606 #[inline]
608 pub fn header(&self) -> &Header {
609 &self.header
610 }
611
612 #[inline]
614 pub fn header_mut(&mut self) -> &mut Header {
615 &mut self.header
616 }
617
618 #[inline]
620 pub fn payload(&self) -> &BytesMut {
621 &self.payload
622 }
623
624 #[inline]
626 pub fn parts(self) -> (Header, BytesMut) {
627 (self.header, self.payload)
628 }
629}