1use bytes::{Buf, BufMut, Bytes, BytesMut};
10
11use crate::error::{CloseReason, Error, Result};
12use crate::simd::apply_mask;
13use crate::utf8::validate_utf8;
14use crate::{MEDIUM_MESSAGE_THRESHOLD, SMALL_MESSAGE_THRESHOLD};
15
/// WebSocket frame opcode, stored as the 4-bit value carried in the low
/// nibble of the first header byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum OpCode {
    /// Continuation of a fragmented message.
    Continuation = 0x0,
    /// Text data frame.
    Text = 0x1,
    /// Binary data frame.
    Binary = 0x2,
    /// Connection close control frame.
    Close = 0x8,
    /// Ping control frame.
    Ping = 0x9,
    /// Pong control frame.
    Pong = 0xA,
}

impl OpCode {
    /// Converts a raw opcode value into an [`OpCode`].
    ///
    /// Returns `None` for every value that does not correspond to a variant.
    #[inline]
    pub fn from_u8(byte: u8) -> Option<Self> {
        let opcode = match byte {
            0x0 => Self::Continuation,
            0x1 => Self::Text,
            0x2 => Self::Binary,
            0x8 => Self::Close,
            0x9 => Self::Ping,
            0xA => Self::Pong,
            _ => return None,
        };
        Some(opcode)
    }

    /// Returns `true` for the control opcodes (`Close`, `Ping`, `Pong`),
    /// i.e. the variants whose numeric value is `0x8` or above.
    #[inline]
    pub fn is_control(&self) -> bool {
        matches!(self, Self::Close | Self::Ping | Self::Pong)
    }

    /// Returns `true` for the data opcodes (`Continuation`, `Text`, `Binary`),
    /// i.e. the variants whose numeric value is `0x2` or below.
    #[inline]
    pub fn is_data(&self) -> bool {
        matches!(self, Self::Continuation | Self::Text | Self::Binary)
    }
}
61
62#[derive(Debug, Clone)]
64pub struct FrameHeader {
65 pub fin: bool,
67 pub rsv1: bool,
69 pub rsv2: bool,
71 pub rsv3: bool,
73 pub opcode: OpCode,
75 pub masked: bool,
77 pub payload_len: u64,
79 pub mask: Option<[u8; 4]>,
81}
82
83impl FrameHeader {
84 #[inline]
86 pub fn header_size(&self) -> usize {
87 let mut size = 2; if self.payload_len > MEDIUM_MESSAGE_THRESHOLD as u64 {
91 size += 8;
92 } else if self.payload_len > SMALL_MESSAGE_THRESHOLD as u64 {
93 size += 2;
94 }
95
96 if self.masked {
98 size += 4;
99 }
100
101 size
102 }
103
104 #[inline]
106 pub fn encode(&self, buf: &mut BytesMut) {
107 let mut b0 = self.opcode as u8;
109 if self.fin {
110 b0 |= 0x80;
111 }
112 if self.rsv1 {
113 b0 |= 0x40;
114 }
115 if self.rsv2 {
116 b0 |= 0x20;
117 }
118 if self.rsv3 {
119 b0 |= 0x10;
120 }
121 buf.put_u8(b0);
122
123 let mask_bit = if self.masked { 0x80 } else { 0x00 };
125
126 if self.payload_len <= SMALL_MESSAGE_THRESHOLD as u64 {
127 buf.put_u8(mask_bit | self.payload_len as u8);
128 } else if self.payload_len <= MEDIUM_MESSAGE_THRESHOLD as u64 {
129 buf.put_u8(mask_bit | 126);
130 buf.put_u16(self.payload_len as u16);
131 } else {
132 buf.put_u8(mask_bit | 127);
133 buf.put_u64(self.payload_len);
134 }
135
136 if let Some(mask) = self.mask {
138 buf.put_slice(&mask);
139 }
140 }
141}
142
143#[derive(Debug, Clone)]
145pub struct Frame {
146 pub header: FrameHeader,
148 pub payload: Bytes,
150}
151
152impl Frame {
153 pub fn new(opcode: OpCode, payload: Bytes, fin: bool) -> Self {
155 Self {
156 header: FrameHeader {
157 fin,
158 rsv1: false,
159 rsv2: false,
160 rsv3: false,
161 opcode,
162 masked: false,
163 payload_len: payload.len() as u64,
164 mask: None,
165 },
166 payload,
167 }
168 }
169
170 #[inline]
172 pub fn text(data: impl Into<Bytes>) -> Self {
173 Self::new(OpCode::Text, data.into(), true)
174 }
175
176 #[inline]
178 pub fn binary(data: impl Into<Bytes>) -> Self {
179 Self::new(OpCode::Binary, data.into(), true)
180 }
181
182 #[inline]
184 pub fn ping(data: impl Into<Bytes>) -> Self {
185 Self::new(OpCode::Ping, data.into(), true)
186 }
187
188 #[inline]
190 pub fn pong(data: impl Into<Bytes>) -> Self {
191 Self::new(OpCode::Pong, data.into(), true)
192 }
193
194 #[inline]
196 pub fn close(code: u16, reason: &str) -> Self {
197 let mut payload = BytesMut::with_capacity(2 + reason.len());
198 payload.put_u16(code);
199 payload.put_slice(reason.as_bytes());
200 Self::new(OpCode::Close, payload.freeze(), true)
201 }
202
203 #[inline]
205 pub fn close_empty() -> Self {
206 Self::new(OpCode::Close, Bytes::new(), true)
207 }
208
209 #[inline]
211 pub fn is_control(&self) -> bool {
212 self.header.opcode.is_control()
213 }
214
215 #[inline]
217 pub fn is_final(&self) -> bool {
218 self.header.fin
219 }
220
221 pub fn as_text(&self) -> Result<&str> {
223 if !validate_utf8(&self.payload) {
224 return Err(Error::InvalidUtf8);
225 }
226 Ok(unsafe { std::str::from_utf8_unchecked(&self.payload) })
228 }
229
230 pub fn parse_close(&self) -> Option<CloseReason> {
232 if self.payload.len() < 2 {
233 return None;
234 }
235 let code = u16::from_be_bytes([self.payload[0], self.payload[1]]);
236 let reason = if self.payload.len() > 2 {
237 String::from_utf8_lossy(&self.payload[2..]).into_owned()
238 } else {
239 String::new()
240 };
241 Some(CloseReason::new(code, reason))
242 }
243}
244
/// Position of the incremental parser within the frame currently being read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParseState {
    /// Waiting for the two base header bytes.
    Header,
    /// Base header seen with length marker 126; collecting the 16-bit length.
    ExtendedLen16,
    /// Base header seen with length marker 127; collecting the 64-bit length.
    ExtendedLen64,
    /// Length known; collecting the 4-byte masking key.
    Mask,
    /// Header complete; waiting for the full payload.
    Payload,
}
259
260pub struct FrameParser {
265 state: ParseState,
266 header_buf: [u8; 14],
268 header_len: usize,
269 header: Option<FrameHeader>,
271 max_frame_size: usize,
273 expect_masked: bool,
275 allow_rsv1: bool,
277}
278
279impl FrameParser {
280 pub fn new(max_frame_size: usize, expect_masked: bool) -> Self {
282 Self {
283 state: ParseState::Header,
284 header_buf: [0; 14],
285 header_len: 0,
286 header: None,
287 max_frame_size,
288 expect_masked,
289 allow_rsv1: false,
290 }
291 }
292
293 pub fn with_compression(max_frame_size: usize, expect_masked: bool) -> Self {
295 Self {
296 state: ParseState::Header,
297 header_buf: [0; 14],
298 header_len: 0,
299 header: None,
300 max_frame_size,
301 expect_masked,
302 allow_rsv1: true,
303 }
304 }
305
306 pub fn set_compression(&mut self, enabled: bool) {
308 self.allow_rsv1 = enabled;
309 }
310
311 #[inline]
313 fn reset(&mut self) {
314 self.state = ParseState::Header;
315 self.header_len = 0;
316 self.header = None;
317 }
318
319 #[inline]
326 pub fn parse(&mut self, buf: &mut BytesMut) -> Result<Option<Frame>> {
327 if self.state == ParseState::Header && !self.expect_masked && buf.len() >= 2 {
330 let b0 = buf[0];
331 let b1 = buf[1];
332 let len_byte = b1 & 0x7F;
333
334 if len_byte <= 125 && (b1 & 0x80) == 0 {
336 let payload_len = len_byte as usize;
337 let total_len = 2 + payload_len;
338
339 if buf.len() >= total_len {
340 let fin = b0 & 0x80 != 0;
342 let rsv1 = b0 & 0x40 != 0;
343 let rsv2 = b0 & 0x20 != 0;
344 let rsv3 = b0 & 0x10 != 0;
345
346 if (rsv1 && !self.allow_rsv1) || rsv2 || rsv3 {
348 return self.parse_slow(buf);
349 }
350
351 if let Some(opcode) = OpCode::from_u8(b0 & 0x0F) {
352 if opcode.is_control() && !fin {
354 return Err(Error::Protocol("control frame must not be fragmented"));
355 }
356
357 buf.advance(2);
359 let payload = buf.split_to(payload_len).freeze();
360
361 return Ok(Some(Frame {
362 header: FrameHeader {
363 fin,
364 rsv1,
365 rsv2,
366 rsv3,
367 opcode,
368 masked: false,
369 payload_len: payload_len as u64,
370 mask: None,
371 },
372 payload,
373 }));
374 }
375 }
376 }
377 }
378
379 if self.state == ParseState::Header && self.expect_masked && buf.len() >= 6 {
382 let b0 = buf[0];
383 let b1 = buf[1];
384 let len_byte = b1 & 0x7F;
385
386 if len_byte <= 125 && (b1 & 0x80) != 0 {
388 let payload_len = len_byte as usize;
389 let total_len = 2 + 4 + payload_len; if buf.len() >= total_len {
392 let fin = b0 & 0x80 != 0;
394 let rsv1 = b0 & 0x40 != 0;
395 let rsv2 = b0 & 0x20 != 0;
396 let rsv3 = b0 & 0x10 != 0;
397
398 if (rsv1 && !self.allow_rsv1) || rsv2 || rsv3 {
400 return self.parse_slow(buf);
401 }
402
403 if let Some(opcode) = OpCode::from_u8(b0 & 0x0F) {
404 if opcode.is_control() && !fin {
406 return Err(Error::Protocol("control frame must not be fragmented"));
407 }
408
409 let mask = [buf[2], buf[3], buf[4], buf[5]];
411
412 buf.advance(6);
414 let mut payload = buf.split_to(payload_len);
415 apply_mask(&mut payload, mask);
416
417 return Ok(Some(Frame {
418 header: FrameHeader {
419 fin,
420 rsv1,
421 rsv2,
422 rsv3,
423 opcode,
424 masked: true,
425 payload_len: payload_len as u64,
426 mask: Some(mask),
427 },
428 payload: payload.freeze(),
429 }));
430 }
431 }
432 }
433 }
434
435 self.parse_slow(buf)
436 }
437
438 fn parse_slow(&mut self, buf: &mut BytesMut) -> Result<Option<Frame>> {
440 const DEBUG: bool = false;
441 loop {
442 if DEBUG && !buf.is_empty() {
443 eprintln!(
444 "[PARSER] State: {:?}, buf_len: {}, header_len: {}",
445 self.state,
446 buf.len(),
447 self.header_len
448 );
449 }
450 match self.state {
451 ParseState::Header => {
452 if buf.len() < 2 {
453 return Ok(None);
454 }
455
456 let b0 = buf[0];
458 let b1 = buf[1];
459
460 let fin = b0 & 0x80 != 0;
462 let rsv1 = b0 & 0x40 != 0;
463 let rsv2 = b0 & 0x20 != 0;
464 let rsv3 = b0 & 0x10 != 0;
465
466 if rsv1 && !self.allow_rsv1 {
469 return Err(Error::Protocol(
470 "RSV1 must be 0 (compression not negotiated)",
471 ));
472 }
473 if rsv2 || rsv3 {
474 return Err(Error::Protocol("RSV2 and RSV3 must be 0"));
475 }
476
477 let opcode =
478 OpCode::from_u8(b0 & 0x0F).ok_or(Error::InvalidFrame("invalid opcode"))?;
479
480 if opcode.is_control() && !fin {
482 return Err(Error::Protocol("control frame must not be fragmented"));
483 }
484
485 let masked = b1 & 0x80 != 0;
487 let len_byte = b1 & 0x7F;
488
489 if self.expect_masked && !masked {
491 return Err(Error::Protocol("client frames must be masked"));
492 }
493 if !self.expect_masked && masked {
494 return Err(Error::Protocol("server frames must not be masked"));
495 }
496
497 let (payload_len, header_size) = if len_byte <= 125 {
499 (len_byte as u64, 2)
500 } else if len_byte == 126 {
501 if buf.len() < 4 {
502 self.header_buf[0] = b0;
504 self.header_buf[1] = b1;
505 self.header_len = 2;
506 buf.advance(2); self.state = ParseState::ExtendedLen16;
508 return Ok(None);
509 }
510 let len = u16::from_be_bytes([buf[2], buf[3]]) as u64;
511 if len < 126 {
513 return Err(Error::Protocol("payload length not minimal"));
514 }
515 (len, 4)
516 } else {
517 if buf.len() < 10 {
519 self.header_buf[0] = b0;
520 self.header_buf[1] = b1;
521 self.header_len = 2;
522 buf.advance(2); self.state = ParseState::ExtendedLen64;
524 return Ok(None);
525 }
526 let len = u64::from_be_bytes([
527 buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
528 ]);
529 if len <= 0xFFFF {
531 return Err(Error::Protocol("payload length not minimal"));
532 }
533 if len >> 63 != 0 {
535 return Err(Error::Protocol("payload length MSB must be 0"));
536 }
537 (len, 10)
538 };
539
540 if opcode.is_control() && payload_len > 125 {
542 return Err(Error::Protocol("control frame too large"));
543 }
544
545 if payload_len > self.max_frame_size as u64 {
547 return Err(Error::FrameTooLarge);
548 }
549
550 let total_header = header_size + if masked { 4 } else { 0 };
552 if buf.len() < total_header {
553 let to_copy = buf.len().min(14);
555 self.header_buf[..to_copy].copy_from_slice(&buf[..to_copy]);
556 self.header_len = to_copy;
557 buf.advance(to_copy);
558
559 self.header = Some(FrameHeader {
561 fin,
562 rsv1,
563 rsv2,
564 rsv3,
565 opcode,
566 masked,
567 payload_len,
568 mask: None,
569 });
570
571 self.state = ParseState::Mask;
572 return Ok(None);
573 }
574
575 let mask = if masked {
576 Some([
577 buf[header_size],
578 buf[header_size + 1],
579 buf[header_size + 2],
580 buf[header_size + 3],
581 ])
582 } else {
583 None
584 };
585
586 buf.advance(total_header);
588
589 self.header = Some(FrameHeader {
590 fin,
591 rsv1,
592 rsv2,
593 rsv3,
594 opcode,
595 masked,
596 payload_len,
597 mask,
598 });
599 self.state = ParseState::Payload;
600 }
601
602 ParseState::ExtendedLen16 => {
603 let target_len = 4;
605 let needed = target_len - self.header_len;
606 if buf.len() < needed {
607 let to_copy = buf.len();
609 self.header_buf[self.header_len..self.header_len + to_copy]
610 .copy_from_slice(&buf[..to_copy]);
611 self.header_len += to_copy;
612 buf.advance(to_copy);
613 return Ok(None);
614 }
615
616 self.header_buf[self.header_len..target_len].copy_from_slice(&buf[..needed]);
618 buf.advance(needed);
619 self.header_len = target_len;
620
621 let payload_len =
622 u16::from_be_bytes([self.header_buf[2], self.header_buf[3]]) as u64;
623
624 if payload_len < 126 {
625 return Err(Error::Protocol("payload length not minimal"));
626 }
627
628 self.parse_header_with_len(payload_len)?;
630
631 if self.header.as_ref().unwrap().masked {
632 self.state = ParseState::Mask;
633 } else {
634 self.state = ParseState::Payload;
635 }
636 }
637
638 ParseState::ExtendedLen64 => {
639 let target_len = 10;
641 let needed = target_len - self.header_len;
642 if buf.len() < needed {
643 let to_copy = buf.len();
645 self.header_buf[self.header_len..self.header_len + to_copy]
646 .copy_from_slice(&buf[..to_copy]);
647 self.header_len += to_copy;
648 buf.advance(to_copy);
649 return Ok(None);
650 }
651
652 self.header_buf[self.header_len..target_len].copy_from_slice(&buf[..needed]);
653 buf.advance(needed);
654 self.header_len = target_len;
655
656 let payload_len = u64::from_be_bytes([
657 self.header_buf[2],
658 self.header_buf[3],
659 self.header_buf[4],
660 self.header_buf[5],
661 self.header_buf[6],
662 self.header_buf[7],
663 self.header_buf[8],
664 self.header_buf[9],
665 ]);
666
667 if payload_len <= 0xFFFF {
668 return Err(Error::Protocol("payload length not minimal"));
669 }
670 if payload_len >> 63 != 0 {
671 return Err(Error::Protocol("payload length MSB must be 0"));
672 }
673
674 self.parse_header_with_len(payload_len)?;
675
676 if self.header.as_ref().unwrap().masked {
677 self.state = ParseState::Mask;
678 } else {
679 self.state = ParseState::Payload;
680 }
681 }
682
683 ParseState::Mask => {
684 let header = self.header.as_mut().unwrap();
685 let header_base = if header.payload_len > MEDIUM_MESSAGE_THRESHOLD as u64 {
687 10 } else if header.payload_len > SMALL_MESSAGE_THRESHOLD as u64 {
689 4 } else {
691 2 };
693 let target_len = header_base + 4;
695 let needed = target_len - self.header_len;
696
697 if DEBUG {
698 eprintln!(
699 "[PARSER] Mask state: header_base={}, target_len={}, header_len={}, needed={}, payload_len={}",
700 header_base, target_len, self.header_len, needed, header.payload_len
701 );
702 eprintln!(
703 "[PARSER] header_buf so far: {:?}",
704 &self.header_buf[..self.header_len]
705 );
706 }
707
708 if buf.len() < needed {
709 let to_copy = buf.len();
711 self.header_buf[self.header_len..self.header_len + to_copy]
712 .copy_from_slice(&buf[..to_copy]);
713 self.header_len += to_copy;
714 buf.advance(to_copy);
715 return Ok(None);
716 }
717
718 self.header_buf[self.header_len..target_len].copy_from_slice(&buf[..needed]);
720 buf.advance(needed);
721 self.header_len = target_len;
722
723 header.mask = Some([
725 self.header_buf[header_base],
726 self.header_buf[header_base + 1],
727 self.header_buf[header_base + 2],
728 self.header_buf[header_base + 3],
729 ]);
730
731 self.state = ParseState::Payload;
732 }
733
734 ParseState::Payload => {
735 let header = self.header.as_ref().unwrap();
736 let payload_len = header.payload_len as usize;
737
738 if DEBUG {
739 eprintln!(
740 "[PARSER] Payload state: need {} bytes, have {} bytes (opcode: {:?}, fin: {}, rsv1: {})",
741 payload_len,
742 buf.len(),
743 header.opcode,
744 header.fin,
745 header.rsv1
746 );
747 if !buf.is_empty() {
748 eprintln!(
749 "[PARSER] First 16 bytes of buffer: {:?}",
750 &buf[..buf.len().min(16)]
751 );
752 }
753 }
754
755 if buf.len() < payload_len {
756 if DEBUG {
757 eprintln!("[PARSER] Not enough payload data, waiting...");
758 }
759 return Ok(None);
760 }
761
762 if DEBUG {
763 eprintln!(
764 "[PARSER] Extracting payload of {} bytes, buf will have {} bytes remaining",
765 payload_len,
766 buf.len() - payload_len
767 );
768 }
769
770 let mut payload = buf.split_to(payload_len);
772
773 if let Some(mask) = header.mask {
774 apply_mask(&mut payload, mask);
775 }
776
777 let frame = Frame {
778 header: self.header.take().unwrap(),
779 payload: payload.freeze(),
780 };
781
782 if DEBUG {
783 eprintln!(
784 "[PARSER] Frame complete! Resetting parser state. Buffer now has {} bytes",
785 buf.len()
786 );
787 }
788
789 self.reset();
790 return Ok(Some(frame));
791 }
792 }
793 }
794 }
795
796 fn parse_header_with_len(&mut self, payload_len: u64) -> Result<()> {
798 let b0 = self.header_buf[0];
799 let b1 = self.header_buf[1];
800
801 let fin = b0 & 0x80 != 0;
802 let rsv1 = b0 & 0x40 != 0;
803 let rsv2 = b0 & 0x20 != 0;
804 let rsv3 = b0 & 0x10 != 0;
805 let opcode = OpCode::from_u8(b0 & 0x0F).ok_or(Error::InvalidFrame("invalid opcode"))?;
806 let masked = b1 & 0x80 != 0;
807
808 if opcode.is_control() && payload_len > 125 {
809 return Err(Error::Protocol("control frame too large"));
810 }
811
812 if payload_len > self.max_frame_size as u64 {
813 return Err(Error::FrameTooLarge);
814 }
815
816 self.header = Some(FrameHeader {
817 fin,
818 rsv1,
819 rsv2,
820 rsv3,
821 opcode,
822 masked,
823 payload_len,
824 mask: None,
825 });
826
827 Ok(())
828 }
829}
830
831#[inline]
836pub fn encode_frame(
837 buf: &mut BytesMut,
838 opcode: OpCode,
839 payload: &[u8],
840 fin: bool,
841 mask: Option<[u8; 4]>,
842) {
843 encode_frame_with_rsv(buf, opcode, payload, fin, mask, false)
844}
845
846#[inline]
850pub fn encode_frame_with_rsv(
851 buf: &mut BytesMut,
852 opcode: OpCode,
853 payload: &[u8],
854 fin: bool,
855 mask: Option<[u8; 4]>,
856 rsv1: bool,
857) {
858 let payload_len = payload.len();
859
860 let ext_len_size = if payload_len > MEDIUM_MESSAGE_THRESHOLD {
862 8
863 } else if payload_len > SMALL_MESSAGE_THRESHOLD {
864 2
865 } else {
866 0
867 };
868 let mask_size = if mask.is_some() { 4 } else { 0 };
869 let header_size = 2 + ext_len_size + mask_size;
870 let total_size = header_size + payload_len;
871
872 buf.reserve(total_size);
874
875 unsafe {
877 let base = buf.as_mut_ptr().add(buf.len());
878 let mut offset = 0;
879
880 let mut b0 = opcode as u8;
882 if fin {
883 b0 |= 0x80;
884 }
885 if rsv1 {
886 b0 |= 0x40;
887 }
888 base.add(offset).write(b0);
889 offset += 1;
890
891 let mask_bit = if mask.is_some() { 0x80u8 } else { 0x00u8 };
893
894 if payload_len <= SMALL_MESSAGE_THRESHOLD {
895 base.add(offset).write(mask_bit | payload_len as u8);
896 offset += 1;
897 } else if payload_len <= MEDIUM_MESSAGE_THRESHOLD {
898 base.add(offset).write(mask_bit | 126);
899 offset += 1;
900 let len_bytes = (payload_len as u16).to_be_bytes();
901 std::ptr::copy_nonoverlapping(len_bytes.as_ptr(), base.add(offset), 2);
902 offset += 2;
903 } else {
904 base.add(offset).write(mask_bit | 127);
905 offset += 1;
906 let len_bytes = (payload_len as u64).to_be_bytes();
907 std::ptr::copy_nonoverlapping(len_bytes.as_ptr(), base.add(offset), 8);
908 offset += 8;
909 }
910
911 if let Some(m) = mask {
913 std::ptr::copy_nonoverlapping(m.as_ptr(), base.add(offset), 4);
915 offset += 4;
916
917 let payload_dst = base.add(offset);
919 encode_payload_masked_inline(payload_dst, payload.as_ptr(), payload_len, m);
920 } else {
921 std::ptr::copy_nonoverlapping(payload.as_ptr(), base.add(offset), payload_len);
923 }
924
925 buf.set_len(buf.len() + total_size);
927 }
928}
929
/// Copies `len` bytes from `src` to `dst`, XOR-ing byte `i` with
/// `mask[i % 4]` on the way.
///
/// Works in 8-byte words, then at most one 4-byte word, then single bytes.
///
/// # Safety
///
/// * `src` must be valid for reads of `len` bytes.
/// * `dst` must be valid for writes of `len` bytes.
/// * The code performs plain unaligned reads and writes, so no alignment is
///   required of either pointer.
#[inline]
unsafe fn encode_payload_masked_inline(dst: *mut u8, src: *const u8, len: usize, mask: [u8; 4]) {
    // Native-endian words keep the in-memory byte order of the key, so the
    // result is the same on little- and big-endian targets.
    let key32 = u32::from_ne_bytes(mask);
    let key64 = (u64::from(key32) << 32) | u64::from(key32);

    let wide_end = len - len % 8;
    let mut pos = 0;

    unsafe {
        // 8 bytes at a time; `pos` stays a multiple of 4, so the key phase
        // is always aligned with the start of each word.
        while pos < wide_end {
            let word = std::ptr::read_unaligned(src.add(pos) as *const u64);
            std::ptr::write_unaligned(dst.add(pos) as *mut u64, word ^ key64);
            pos += 8;
        }

        // At most one remaining 4-byte word.
        if len - pos >= 4 {
            let word = std::ptr::read_unaligned(src.add(pos) as *const u32);
            std::ptr::write_unaligned(dst.add(pos) as *mut u32, word ^ key32);
            pos += 4;
        }

        // Up to three trailing bytes.
        for idx in pos..len {
            dst.add(idx).write(src.add(idx).read() ^ mask[idx % 4]);
        }
    }
}
965
#[cfg(test)]
mod tests {
    use super::*;

    /// Control and data opcodes are classified correctly.
    #[test]
    fn test_opcode() {
        for control in [OpCode::Ping, OpCode::Pong, OpCode::Close] {
            assert!(control.is_control());
        }
        for data in [OpCode::Text, OpCode::Binary] {
            assert!(!data.is_control());
        }
        for data in [OpCode::Text, OpCode::Binary, OpCode::Continuation] {
            assert!(data.is_data());
        }
    }

    /// A complete small unmasked text frame is parsed in one call.
    #[test]
    fn test_parse_small_unmasked() {
        let wire: &[u8] = &[0x81, 0x05, b'h', b'e', b'l', b'l', b'o'];
        let mut buf = BytesMut::from(wire);
        let mut parser = FrameParser::new(1024 * 1024, false);

        let frame = parser.parse(&mut buf).unwrap().unwrap();
        assert!(frame.header.fin);
        assert_eq!(frame.header.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), b"hello");
    }

    /// A complete small masked text frame is parsed and unmasked.
    #[test]
    fn test_parse_small_masked() {
        let mask = [0x37, 0xfa, 0x21, 0x3d];
        let mut masked_payload = *b"Hello";
        apply_mask(&mut masked_payload, mask);

        // FIN + Text, MASK bit + length 5, key, masked payload.
        let mut buf = BytesMut::new();
        buf.put_slice(&[0x81, 0x85]);
        buf.put_slice(&mask);
        buf.put_slice(&masked_payload);

        let mut parser = FrameParser::new(1024 * 1024, true);
        let frame = parser.parse(&mut buf).unwrap().unwrap();
        assert_eq!(frame.payload.as_ref(), b"Hello");
    }

    /// A frame using the 16-bit extended length is parsed.
    #[test]
    fn test_parse_medium_length() {
        let payload = vec![0x42u8; 200];

        // FIN + Binary, length marker 126, 16-bit length, payload.
        let mut buf = BytesMut::new();
        buf.put_slice(&[0x82, 126]);
        buf.put_u16(200);
        buf.put_slice(&payload);

        let mut parser = FrameParser::new(1024 * 1024, false);
        let frame = parser.parse(&mut buf).unwrap().unwrap();
        assert_eq!(frame.header.opcode, OpCode::Binary);
        assert_eq!(frame.payload.len(), 200);
    }

    /// Unmasked encoding produces header bytes followed by the raw payload.
    #[test]
    fn test_encode_frame() {
        let mut buf = BytesMut::new();
        encode_frame(&mut buf, OpCode::Text, b"hello", true, None);

        assert_eq!(buf[0], 0x81);
        assert_eq!(buf[1], 0x05);
        assert_eq!(&buf[2..], b"hello");
    }

    /// Masked encoding writes the key and a payload that unmasks to the input.
    #[test]
    fn test_encode_frame_masked() {
        let mask = [0x01, 0x02, 0x03, 0x04];
        let mut buf = BytesMut::new();
        encode_frame(&mut buf, OpCode::Text, b"test", true, Some(mask));

        assert_eq!(buf[0], 0x81);
        assert_eq!(buf[1], 0x84);
        assert_eq!(&buf[2..6], &mask);

        let mut unmasked = buf[6..].to_vec();
        apply_mask(&mut unmasked, mask);
        assert_eq!(&unmasked, b"test");
    }

    /// A ping without the FIN bit is rejected.
    #[test]
    fn test_control_frame_fragmentation() {
        let wire: &[u8] = &[0x09, 0x00];
        let mut buf = BytesMut::from(wire);
        let mut parser = FrameParser::new(1024, false);

        assert!(parser.parse(&mut buf).is_err());
    }

    /// A close frame round-trips its status code and reason.
    #[test]
    fn test_close_frame() {
        let frame = Frame::close(1000, "goodbye");
        assert_eq!(frame.header.opcode, OpCode::Close);

        let close = frame.parse_close().unwrap();
        assert_eq!(close.code, 1000);
        assert_eq!(close.reason, "goodbye");
    }
}
1073}