1use bytes::{Bytes, BytesMut};
9
10use crate::error::{CloseReason, Error, Result};
11#[cfg(feature = "permessage-deflate")]
12use crate::frame::encode_frame_with_rsv;
13use crate::frame::{Frame, FrameParser, OpCode, encode_frame};
14use crate::utf8::{validate_utf8, validate_utf8_incomplete};
15
16#[cfg(feature = "permessage-deflate")]
17use crate::deflate::{DeflateConfig, DeflateContext};
18
/// Which side of a WebSocket connection this endpoint plays.
///
/// The role decides masking: a client masks the frames it sends, and a
/// server expects every incoming frame to be masked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// The connecting side; masks outgoing frames.
    Client,
    /// The accepting side; expects masked incoming frames.
    Server,
}
27
28#[derive(Debug, Clone)]
33pub enum Message {
34 Text(Bytes),
36 Binary(Bytes),
38 Ping(Bytes),
40 Pong(Bytes),
42 Close(Option<CloseReason>),
44}
45
46impl Message {
47 #[inline]
49 pub fn text(s: impl Into<String>) -> Self {
50 Message::Text(Bytes::from(s.into()))
51 }
52
53 #[inline]
55 pub fn binary(data: impl Into<Bytes>) -> Self {
56 Message::Binary(data.into())
57 }
58
59 #[inline]
61 pub fn ping(data: impl Into<Bytes>) -> Self {
62 Message::Ping(data.into())
63 }
64
65 #[inline]
67 pub fn pong(data: impl Into<Bytes>) -> Self {
68 Message::Pong(data.into())
69 }
70
71 #[inline]
73 pub fn is_close(&self) -> bool {
74 matches!(self, Message::Close(_))
75 }
76
77 #[inline]
79 pub fn is_text(&self) -> bool {
80 matches!(self, Message::Text(_))
81 }
82
83 #[inline]
85 pub fn is_binary(&self) -> bool {
86 matches!(self, Message::Binary(_))
87 }
88
89 #[inline]
91 pub fn is_ping(&self) -> bool {
92 matches!(self, Message::Ping(_))
93 }
94
95 #[inline]
97 pub fn is_pong(&self) -> bool {
98 matches!(self, Message::Pong(_))
99 }
100
101 #[inline]
103 pub fn is_control(&self) -> bool {
104 matches!(
105 self,
106 Message::Ping(_) | Message::Pong(_) | Message::Close(_)
107 )
108 }
109
110 #[inline]
115 pub fn as_text(&self) -> Option<&str> {
116 match self {
117 Message::Text(b) => {
118 Some(unsafe { std::str::from_utf8_unchecked(b) })
120 }
121 _ => None,
122 }
123 }
124
125 #[inline]
127 pub fn as_bytes(&self) -> &[u8] {
128 match self {
129 Message::Text(b) => b,
130 Message::Binary(b) => b,
131 Message::Ping(b) => b,
132 Message::Pong(b) => b,
133 Message::Close(_) => &[],
134 }
135 }
136
137 pub fn into_text(self) -> Option<String> {
141 match self {
142 Message::Text(b) => {
143 Some(unsafe { String::from_utf8_unchecked(b.to_vec()) })
145 }
146 _ => None,
147 }
148 }
149
150 #[inline]
154 pub fn text_bytes(&self) -> Option<&Bytes> {
155 match self {
156 Message::Text(b) => Some(b),
157 _ => None,
158 }
159 }
160
161 pub fn into_bytes(self) -> Bytes {
163 match self {
164 Message::Text(b) => b,
165 Message::Binary(b) => b,
166 Message::Ping(b) => b,
167 Message::Pong(b) => b,
168 Message::Close(_) => Bytes::new(),
169 }
170 }
171}
172
173impl From<String> for Message {
174 fn from(s: String) -> Self {
175 Message::Text(Bytes::from(s))
176 }
177}
178
179impl From<&str> for Message {
180 fn from(s: &str) -> Self {
181 Message::Text(Bytes::copy_from_slice(s.as_bytes()))
182 }
183}
184
185impl From<Vec<u8>> for Message {
186 fn from(v: Vec<u8>) -> Self {
187 Message::Binary(Bytes::from(v))
188 }
189}
190
191impl From<Bytes> for Message {
192 fn from(b: Bytes) -> Self {
193 Message::Binary(b)
194 }
195}
196
197impl From<&[u8]> for Message {
198 fn from(b: &[u8]) -> Self {
199 Message::Binary(Bytes::copy_from_slice(b))
200 }
201}
202
/// Closing-handshake state of a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No Close frame has been sent or received.
    Open,
    /// We sent a Close frame and are waiting for the peer's.
    CloseSent,
    /// The peer sent a Close frame and we have not answered yet.
    CloseReceived,
    /// Close frames have been exchanged in both directions.
    Closed,
}
215
216pub struct Protocol {
220 pub(crate) role: Role,
222 state: State,
224 pub(crate) parser: FrameParser,
226 pub(crate) fragment_buf: BytesMut,
228 pub(crate) fragment_opcode: Option<OpCode>,
230 pub(crate) max_message_size: usize,
232 pending_close: Option<CloseReason>,
234}
235
236impl Protocol {
237 pub fn new(role: Role, max_frame_size: usize, max_message_size: usize) -> Self {
239 let expect_masked = role == Role::Server;
240
241 Self {
242 role,
243 state: State::Open,
244 parser: FrameParser::new(max_frame_size, expect_masked),
245 fragment_buf: BytesMut::new(),
246 fragment_opcode: None,
247 max_message_size,
248 pending_close: None,
249 }
250 }
251
252 #[inline]
254 pub fn is_closed(&self) -> bool {
255 self.state == State::Closed
256 }
257
258 #[inline]
260 pub fn is_closing(&self) -> bool {
261 matches!(self.state, State::CloseSent | State::CloseReceived)
262 }
263
264 #[inline]
269 pub fn process(&mut self, buf: &mut BytesMut) -> Result<Vec<Message>> {
270 let mut messages = Vec::new();
271 self.process_into(buf, &mut messages)?;
272 Ok(messages)
273 }
274
275 #[inline]
279 pub fn process_into(&mut self, buf: &mut BytesMut, messages: &mut Vec<Message>) -> Result<()> {
280 messages.clear();
281
282 while !buf.is_empty() {
283 match self.parser.parse(buf)? {
284 Some(frame) => {
285 if let Some(msg) = self.handle_frame(frame)? {
286 messages.push(msg);
287 }
288 }
289 None => break,
290 }
291 }
292
293 Ok(())
294 }
295
296 fn handle_frame(&mut self, frame: Frame) -> Result<Option<Message>> {
298 match frame.header.opcode {
299 OpCode::Continuation => self.handle_continuation(frame),
300 OpCode::Text => self.handle_text(frame),
301 OpCode::Binary => self.handle_binary(frame),
302 OpCode::Close => self.handle_close(frame),
303 OpCode::Ping => self.handle_ping(frame),
304 OpCode::Pong => self.handle_pong(frame),
305 }
306 }
307
308 fn handle_text(&mut self, frame: Frame) -> Result<Option<Message>> {
310 if self.fragment_opcode.is_some() {
311 return Err(Error::Protocol("expected continuation frame"));
312 }
313
314 if frame.header.fin {
315 if !validate_utf8(&frame.payload) {
317 return Err(Error::InvalidUtf8);
318 }
319 Ok(Some(Message::Text(frame.payload)))
321 } else {
322 self.start_fragment(OpCode::Text, frame.payload)?;
324 Ok(None)
325 }
326 }
327
328 fn handle_binary(&mut self, frame: Frame) -> Result<Option<Message>> {
330 if self.fragment_opcode.is_some() {
331 return Err(Error::Protocol("expected continuation frame"));
332 }
333
334 if frame.header.fin {
335 Ok(Some(Message::Binary(frame.payload)))
337 } else {
338 self.start_fragment(OpCode::Binary, frame.payload)?;
340 Ok(None)
341 }
342 }
343
344 fn handle_continuation(&mut self, frame: Frame) -> Result<Option<Message>> {
346 let opcode = self
347 .fragment_opcode
348 .ok_or(Error::Protocol("unexpected continuation frame"))?;
349
350 let new_size = self.fragment_buf.len() + frame.payload.len();
352 if new_size > self.max_message_size {
353 return Err(Error::MessageTooLarge);
354 }
355
356 self.fragment_buf.extend_from_slice(&frame.payload);
357
358 if frame.header.fin {
359 self.complete_fragment(opcode)
361 } else {
362 if opcode == OpCode::Text {
364 let (valid, _incomplete) = validate_utf8_incomplete(&self.fragment_buf);
365 if !valid {
366 return Err(Error::InvalidUtf8);
367 }
368 }
369 Ok(None)
370 }
371 }
372
373 pub(crate) fn start_fragment(&mut self, opcode: OpCode, payload: Bytes) -> Result<()> {
375 if payload.len() > self.max_message_size {
376 return Err(Error::MessageTooLarge);
377 }
378
379 self.fragment_opcode = Some(opcode);
380 self.fragment_buf.clear();
381 self.fragment_buf.extend_from_slice(&payload);
382
383 if opcode == OpCode::Text {
385 let (valid, _incomplete) = validate_utf8_incomplete(&self.fragment_buf);
386 if !valid {
387 return Err(Error::InvalidUtf8);
388 }
389 }
390
391 Ok(())
392 }
393
394 fn complete_fragment(&mut self, opcode: OpCode) -> Result<Option<Message>> {
396 self.fragment_opcode = None;
397 let data = self.fragment_buf.split().freeze();
398
399 match opcode {
400 OpCode::Text => {
401 if !validate_utf8(&data) {
402 return Err(Error::InvalidUtf8);
403 }
404 Ok(Some(Message::Text(data)))
406 }
407 OpCode::Binary => Ok(Some(Message::Binary(data))),
408 _ => Err(Error::Protocol("invalid fragment opcode")),
409 }
410 }
411
412 pub(crate) fn handle_close(&mut self, frame: Frame) -> Result<Option<Message>> {
414 let reason = if frame.payload.len() >= 2 {
415 let code = u16::from_be_bytes([frame.payload[0], frame.payload[1]]);
416
417 if !CloseReason::is_valid_code(code) && !(3000..=4999).contains(&code) {
419 return Err(Error::InvalidCloseCode(code));
420 }
421
422 let reason_text = if frame.payload.len() > 2 {
423 let text = &frame.payload[2..];
424 if !validate_utf8(text) {
425 return Err(Error::InvalidUtf8);
426 }
427 String::from_utf8_lossy(text).into_owned()
428 } else {
429 String::new()
430 };
431
432 Some(CloseReason::new(code, reason_text))
433 } else if frame.payload.is_empty() {
434 None
435 } else {
436 return Err(Error::Protocol("invalid close frame payload"));
438 };
439
440 match self.state {
441 State::Open => {
442 self.state = State::CloseReceived;
443 self.pending_close = reason.clone();
444 }
445 State::CloseSent => {
446 self.state = State::Closed;
447 }
448 _ => {}
449 }
450
451 Ok(Some(Message::Close(reason)))
452 }
453
454 pub(crate) fn handle_ping(&mut self, frame: Frame) -> Result<Option<Message>> {
456 Ok(Some(Message::Ping(frame.payload)))
457 }
458
459 pub(crate) fn handle_pong(&mut self, frame: Frame) -> Result<Option<Message>> {
461 Ok(Some(Message::Pong(frame.payload)))
462 }
463
464 pub fn encode_message(&mut self, msg: &Message, buf: &mut BytesMut) -> Result<()> {
466 let mask = if self.role == Role::Client {
467 Some(crate::mask::generate_mask_fast())
468 } else {
469 None
470 };
471
472 match msg {
473 Message::Text(b) => {
474 encode_frame(buf, OpCode::Text, b, true, mask);
475 }
476 Message::Binary(b) => {
477 encode_frame(buf, OpCode::Binary, b, true, mask);
478 }
479 Message::Ping(b) => {
480 encode_frame(buf, OpCode::Ping, b, true, mask);
481 }
482 Message::Pong(b) => {
483 encode_frame(buf, OpCode::Pong, b, true, mask);
484 }
485 Message::Close(reason) => {
486 if self.state == State::Open {
487 self.state = State::CloseSent;
488 }
489
490 let payload = if let Some(r) = reason {
491 let mut p = BytesMut::with_capacity(2 + r.reason.len());
492 p.extend_from_slice(&r.code.to_be_bytes());
493 p.extend_from_slice(r.reason.as_bytes());
494 p.freeze()
495 } else {
496 Bytes::new()
497 };
498
499 encode_frame(buf, OpCode::Close, &payload, true, mask);
500 }
501 }
502
503 Ok(())
504 }
505
506 pub fn encode_pong(&mut self, ping_data: &[u8], buf: &mut BytesMut) {
508 let mask = if self.role == Role::Client {
509 Some(crate::mask::generate_mask_fast())
510 } else {
511 None
512 };
513 encode_frame(buf, OpCode::Pong, ping_data, true, mask);
514 }
515
516 pub fn encode_close_response(&mut self, buf: &mut BytesMut) {
518 let mask = if self.role == Role::Client {
519 Some(crate::mask::generate_mask_fast())
520 } else {
521 None
522 };
523
524 let payload = if let Some(ref reason) = self.pending_close {
525 let mut p = BytesMut::with_capacity(2 + reason.reason.len());
526 p.extend_from_slice(&reason.code.to_be_bytes());
527 p.extend_from_slice(reason.reason.as_bytes());
528 p.freeze()
529 } else {
530 Bytes::new()
531 };
532
533 encode_frame(buf, OpCode::Close, &payload, true, mask);
534
535 if self.state == State::CloseReceived {
536 self.state = State::Closed;
537 }
538 }
539
540 #[cfg(feature = "permessage-deflate")]
542 pub fn enable_compression(&mut self) {
543 self.parser.set_compression(true);
544 }
545}
546
/// WebSocket protocol state machine with permessage-deflate support.
///
/// Delegates framing and the closing handshake to an inner [`Protocol`].
/// It inflates incoming data messages flagged with RSV1 and may deflate
/// outgoing data messages.
#[cfg(feature = "permessage-deflate")]
pub struct CompressedProtocol {
    /// Uncompressed protocol handling framing, fragments and close state.
    inner: Protocol,
    /// Compression/decompression context for this connection.
    deflate: DeflateContext,
    /// Whether the in-progress fragmented message had RSV1 set on its first
    /// frame.
    fragment_compressed: bool,
    /// NOTE(review): only initialized in this file; appears unused.
    decompress_buf: BytesMut,
}
559
#[cfg(feature = "permessage-deflate")]
impl CompressedProtocol {
    /// Creates the server side of a compressed connection using the
    /// negotiated deflate `config`.
    pub fn server(max_frame_size: usize, max_message_size: usize, config: DeflateConfig) -> Self {
        let mut inner = Protocol::new(Role::Server, max_frame_size, max_message_size);
        inner.enable_compression();

        Self {
            inner,
            deflate: DeflateContext::server(config),
            fragment_compressed: false,
            decompress_buf: BytesMut::new(),
        }
    }

    /// Creates the client side of a compressed connection using the
    /// negotiated deflate `config`.
    pub fn client(max_frame_size: usize, max_message_size: usize, config: DeflateConfig) -> Self {
        let mut inner = Protocol::new(Role::Client, max_frame_size, max_message_size);
        inner.enable_compression();

        Self {
            inner,
            deflate: DeflateContext::client(config),
            fragment_compressed: false,
            decompress_buf: BytesMut::new(),
        }
    }

    /// Returns `true` once Close frames have gone in both directions.
    #[inline]
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    /// Returns `true` while only one side has sent a Close frame.
    #[inline]
    pub fn is_closing(&self) -> bool {
        self.inner.is_closing()
    }

    /// Parses every complete frame in `buf` and returns the finished messages.
    ///
    /// # Errors
    /// Propagates the first parse, protocol or decompression error.
    pub fn process(&mut self, buf: &mut BytesMut) -> Result<Vec<Message>> {
        let mut messages = Vec::new();
        self.process_into(buf, &mut messages)?;
        Ok(messages)
    }

    /// Clears `messages`, then parses frames from `buf` until it is empty or
    /// only a partial frame remains, pushing every completed message.
    ///
    /// # Errors
    /// Returns the first parse, protocol or decompression error. Messages
    /// decoded before the error remain in `messages`.
    #[inline]
    pub fn process_into(&mut self, buf: &mut BytesMut, messages: &mut Vec<Message>) -> Result<()> {
        messages.clear();

        while !buf.is_empty() {
            match self.inner.parser.parse(buf)? {
                Some(frame) => {
                    if let Some(msg) = self.handle_frame(frame)? {
                        messages.push(msg);
                    }
                }
                None => break,
            }
        }

        Ok(())
    }

    /// Dispatches a parsed frame. RSV1 marks a compressed message.
    ///
    /// # Errors
    /// `Error::Protocol` if RSV1 is set on a control or continuation frame.
    /// RFC 7692 §6.1 allows the bit only on the first frame of a data
    /// message, and the connection must be failed otherwise.
    fn handle_frame(&mut self, frame: Frame) -> Result<Option<Message>> {
        let compressed = frame.header.rsv1;

        match frame.header.opcode {
            OpCode::Text => self.handle_text(frame, compressed),
            OpCode::Binary => self.handle_binary(frame, compressed),
            _ if compressed => Err(Error::Protocol("RSV1 set on control or continuation frame")),
            OpCode::Continuation => self.handle_continuation(frame),
            OpCode::Close => self.inner.handle_close(frame),
            OpCode::Ping => self.inner.handle_ping(frame),
            OpCode::Pong => self.inner.handle_pong(frame),
        }
    }

    /// Handles the first (or only) frame of a text message, inflating it when
    /// `compressed`.
    ///
    /// # Errors
    /// Protocol, size, decompression or UTF-8 errors.
    fn handle_text(&mut self, frame: Frame, compressed: bool) -> Result<Option<Message>> {
        if self.inner.fragment_opcode.is_some() {
            return Err(Error::Protocol("expected continuation frame"));
        }

        if frame.header.fin {
            let payload = if compressed {
                self.deflate
                    .decompress(&frame.payload, self.inner.max_message_size)?
            } else {
                frame.payload
            };

            if !validate_utf8(&payload) {
                return Err(Error::InvalidUtf8);
            }
            Ok(Some(Message::Text(payload)))
        } else {
            self.fragment_compressed = compressed;

            if compressed {
                // Deflated bytes are not UTF-8, so start the fragment as
                // Binary to skip the incremental text check. Then record the
                // real opcode; the text is validated after inflating.
                self.inner.start_fragment(OpCode::Binary, frame.payload)?;
                self.inner.fragment_opcode = Some(OpCode::Text);
            } else {
                self.inner.start_fragment(OpCode::Text, frame.payload)?;
            }
            Ok(None)
        }
    }

    /// Handles the first (or only) frame of a binary message, inflating it
    /// when `compressed`.
    ///
    /// # Errors
    /// Protocol, size or decompression errors.
    fn handle_binary(&mut self, frame: Frame, compressed: bool) -> Result<Option<Message>> {
        if self.inner.fragment_opcode.is_some() {
            return Err(Error::Protocol("expected continuation frame"));
        }

        if frame.header.fin {
            let payload = if compressed {
                self.deflate
                    .decompress(&frame.payload, self.inner.max_message_size)?
            } else {
                frame.payload
            };
            Ok(Some(Message::Binary(payload)))
        } else {
            self.fragment_compressed = compressed;
            self.inner.start_fragment(OpCode::Binary, frame.payload)?;
            Ok(None)
        }
    }

    /// Appends a continuation frame and completes the message on FIN.
    ///
    /// # Errors
    /// `Error::Protocol` if no fragmented message is in progress, and
    /// `Error::MessageTooLarge` if the accumulated (still compressed)
    /// payload would exceed `max_message_size`.
    fn handle_continuation(&mut self, frame: Frame) -> Result<Option<Message>> {
        let opcode = self
            .inner
            .fragment_opcode
            .ok_or(Error::Protocol("unexpected continuation frame"))?;

        let new_size = self.inner.fragment_buf.len() + frame.payload.len();
        if new_size > self.inner.max_message_size {
            return Err(Error::MessageTooLarge);
        }

        self.inner.fragment_buf.extend_from_slice(&frame.payload);

        if frame.header.fin {
            self.complete_fragment(opcode)
        } else {
            Ok(None)
        }
    }

    /// Finishes a fragmented message, inflating the whole reassembled payload
    /// if its first frame was compressed.
    ///
    /// # Errors
    /// Decompression errors, `Error::InvalidUtf8` for invalid text, and
    /// `Error::Protocol` for a non-data opcode.
    fn complete_fragment(&mut self, opcode: OpCode) -> Result<Option<Message>> {
        self.inner.fragment_opcode = None;
        let compressed_data = self.inner.fragment_buf.split().freeze();

        let data = if self.fragment_compressed {
            self.fragment_compressed = false;
            self.deflate
                .decompress(&compressed_data, self.inner.max_message_size)?
        } else {
            compressed_data
        };

        match opcode {
            OpCode::Text => {
                if !validate_utf8(&data) {
                    return Err(Error::InvalidUtf8);
                }
                Ok(Some(Message::Text(data)))
            }
            OpCode::Binary => Ok(Some(Message::Binary(data))),
            _ => Err(Error::Protocol("invalid fragment opcode")),
        }
    }

    /// Encodes `msg` into `buf`.
    ///
    /// Data messages are deflated (with RSV1 set) when the deflate context
    /// returns compressed output, and sent uncompressed otherwise. Control
    /// messages are never compressed; they are delegated to the inner
    /// protocol, which also tracks the closing-handshake state for Close.
    ///
    /// # Errors
    /// Compression errors, and whatever the inner encoder reports for
    /// control messages.
    pub fn encode_message(&mut self, msg: &Message, buf: &mut BytesMut) -> Result<()> {
        let (opcode, data) = match msg {
            Message::Text(b) => (OpCode::Text, b),
            Message::Binary(b) => (OpCode::Binary, b),
            Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {
                return self.inner.encode_message(msg, buf);
            }
        };

        let mask = (self.inner.role == Role::Client).then(crate::mask::generate_mask_fast);

        if let Some(compressed) = self.deflate.compress(data)? {
            encode_frame_with_rsv(buf, opcode, &compressed, true, mask, true);
        } else {
            encode_frame(buf, opcode, data, true, mask);
        }

        Ok(())
    }

    /// Encodes an uncompressed Pong echoing `ping_data`.
    pub fn encode_pong(&mut self, ping_data: &[u8], buf: &mut BytesMut) {
        self.inner.encode_pong(ping_data, buf);
    }

    /// Encodes the Close frame answering a received Close.
    pub fn encode_close_response(&mut self, buf: &mut BytesMut) {
        self.inner.encode_close_response(buf);
    }

    /// Splits this protocol into independent reader and writer halves.
    ///
    /// The reader gets a new frame parser (built from `max_frame_size`, with
    /// compression enabled) and the new `max_message_size`. It also takes over
    /// any in-progress fragment and the inflate half of the deflate context.
    /// The writer takes the deflate (encoder) half. The closing-handshake
    /// state and any pending close reason are not carried over.
    ///
    /// NOTE(review): the old frame parser is discarded. If it buffers
    /// partial-frame state internally, that state is lost.
    pub fn split(
        self,
        max_frame_size: usize,
        max_message_size: usize,
    ) -> (CompressedReaderProtocol, CompressedWriterProtocol) {
        let role = self.inner.role;

        let mut parser = FrameParser::new(max_frame_size, role == Role::Server);
        parser.set_compression(true);

        let reader = CompressedReaderProtocol {
            role,
            parser,
            fragment_buf: self.inner.fragment_buf,
            fragment_opcode: self.inner.fragment_opcode,
            max_message_size,
            decoder: self.deflate.decoder,
            fragment_compressed: self.fragment_compressed,
        };

        let writer = CompressedWriterProtocol {
            role,
            encoder: self.deflate.encoder,
        };

        (reader, writer)
    }
}
868
/// Receive half of a compressed WebSocket connection.
///
/// Parses frames, reassembles fragments and inflates compressed messages.
/// Unlike [`Protocol`], it keeps no closing-handshake state.
#[cfg(feature = "permessage-deflate")]
pub struct CompressedReaderProtocol {
    /// Side of the connection this half belongs to.
    role: Role,
    /// Incremental frame parser.
    parser: FrameParser,
    /// Payload accumulated so far for an in-progress fragmented message.
    fragment_buf: BytesMut,
    /// Opcode of the in-progress fragmented message, if any.
    fragment_opcode: Option<OpCode>,
    /// Upper bound, in bytes, on a message payload.
    max_message_size: usize,
    /// Inflate context for the peer's messages.
    decoder: crate::deflate::DeflateDecoder,
    /// Whether the in-progress fragmented message had RSV1 set on its first
    /// frame.
    fragment_compressed: bool,
}
889
#[cfg(feature = "permessage-deflate")]
impl CompressedReaderProtocol {
    /// Creates a server-side reader. Incoming frames must be masked, and the
    /// decoder uses the client's negotiated window and context-takeover
    /// settings.
    pub fn server(max_frame_size: usize, max_message_size: usize, config: &DeflateConfig) -> Self {
        Self {
            role: Role::Server,
            parser: FrameParser::new(max_frame_size, true),
            fragment_buf: BytesMut::new(),
            fragment_opcode: None,
            max_message_size,
            decoder: crate::deflate::DeflateDecoder::new(
                config.client_max_window_bits,
                config.client_no_context_takeover,
            ),
            fragment_compressed: false,
        }
    }

    /// Creates a client-side reader. Incoming frames are unmasked, and the
    /// decoder uses the server's negotiated window and context-takeover
    /// settings.
    pub fn client(max_frame_size: usize, max_message_size: usize, config: &DeflateConfig) -> Self {
        Self {
            role: Role::Client,
            parser: FrameParser::new(max_frame_size, false),
            fragment_buf: BytesMut::new(),
            fragment_opcode: None,
            max_message_size,
            decoder: crate::deflate::DeflateDecoder::new(
                config.server_max_window_bits,
                config.server_no_context_takeover,
            ),
            fragment_compressed: false,
        }
    }

    /// Parses every complete frame in `buf` and returns the finished messages.
    ///
    /// # Errors
    /// Propagates the first parse, protocol or decompression error.
    pub fn process(&mut self, buf: &mut BytesMut) -> Result<Vec<Message>> {
        let mut messages = Vec::new();
        self.process_into(buf, &mut messages)?;
        Ok(messages)
    }

    /// Clears `messages`, then parses frames from `buf` until it is empty or
    /// only a partial frame remains, pushing every completed message.
    ///
    /// # Errors
    /// Returns the first parse, protocol or decompression error. Messages
    /// decoded before the error remain in `messages`.
    pub fn process_into(&mut self, buf: &mut BytesMut, messages: &mut Vec<Message>) -> Result<()> {
        messages.clear();

        // The constructors build the parser without enabling compression, so
        // RSV1 acceptance is switched on here, where every reader passes.
        self.parser.set_compression(true);

        while !buf.is_empty() {
            match self.parser.parse(buf)? {
                Some(frame) => {
                    if let Some(msg) = self.handle_frame(frame)? {
                        messages.push(msg);
                    }
                }
                None => break,
            }
        }

        Ok(())
    }

    /// Dispatches a parsed frame. RSV1 marks a compressed message.
    ///
    /// # Errors
    /// `Error::Protocol` if RSV1 is set on a control or continuation frame.
    /// RFC 7692 §6.1 allows the bit only on the first frame of a data
    /// message.
    fn handle_frame(&mut self, frame: Frame) -> Result<Option<Message>> {
        let compressed = frame.header.rsv1;

        match frame.header.opcode {
            OpCode::Text => self.handle_text(frame, compressed),
            OpCode::Binary => self.handle_binary(frame, compressed),
            _ if compressed => Err(Error::Protocol("RSV1 set on control or continuation frame")),
            OpCode::Continuation => self.handle_continuation(frame),
            OpCode::Close => self.handle_close(frame),
            OpCode::Ping => Ok(Some(Message::Ping(frame.payload))),
            OpCode::Pong => Ok(Some(Message::Pong(frame.payload))),
        }
    }

    /// Handles the first (or only) frame of a text message. A final frame is
    /// inflated if `compressed` and then UTF-8 validated. A non-final frame
    /// starts reassembly; validation happens when the message completes.
    fn handle_text(&mut self, frame: Frame, compressed: bool) -> Result<Option<Message>> {
        if self.fragment_opcode.is_some() {
            return Err(Error::Protocol("expected continuation frame"));
        }

        if frame.header.fin {
            let payload = if compressed {
                self.decoder
                    .decompress(&frame.payload, self.max_message_size)?
            } else {
                frame.payload
            };

            if !validate_utf8(&payload) {
                return Err(Error::InvalidUtf8);
            }
            Ok(Some(Message::Text(payload)))
        } else {
            self.fragment_compressed = compressed;
            self.start_fragment(OpCode::Text, frame.payload)?;
            Ok(None)
        }
    }

    /// Handles the first (or only) frame of a binary message, inflating it
    /// when `compressed`.
    fn handle_binary(&mut self, frame: Frame, compressed: bool) -> Result<Option<Message>> {
        if self.fragment_opcode.is_some() {
            return Err(Error::Protocol("expected continuation frame"));
        }

        if frame.header.fin {
            let payload = if compressed {
                self.decoder
                    .decompress(&frame.payload, self.max_message_size)?
            } else {
                frame.payload
            };
            Ok(Some(Message::Binary(payload)))
        } else {
            self.fragment_compressed = compressed;
            self.start_fragment(OpCode::Binary, frame.payload)?;
            Ok(None)
        }
    }

    /// Appends a continuation frame and completes the message on FIN.
    ///
    /// # Errors
    /// `Error::Protocol` if no fragmented message is in progress, and
    /// `Error::MessageTooLarge` if the accumulated payload would exceed
    /// `max_message_size`.
    fn handle_continuation(&mut self, frame: Frame) -> Result<Option<Message>> {
        let opcode = self
            .fragment_opcode
            .ok_or(Error::Protocol("unexpected continuation frame"))?;

        let new_size = self.fragment_buf.len() + frame.payload.len();
        if new_size > self.max_message_size {
            return Err(Error::MessageTooLarge);
        }

        self.fragment_buf.extend_from_slice(&frame.payload);

        if frame.header.fin {
            self.complete_fragment(opcode)
        } else {
            Ok(None)
        }
    }

    /// Starts reassembly of a fragmented message with its first payload.
    ///
    /// # Errors
    /// `Error::MessageTooLarge` if `payload` alone exceeds `max_message_size`.
    fn start_fragment(&mut self, opcode: OpCode, payload: Bytes) -> Result<()> {
        if payload.len() > self.max_message_size {
            return Err(Error::MessageTooLarge);
        }

        self.fragment_opcode = Some(opcode);
        self.fragment_buf.clear();
        self.fragment_buf.extend_from_slice(&payload);
        Ok(())
    }

    /// Finishes a fragmented message, inflating the reassembled payload if
    /// its first frame was compressed, and validating text as UTF-8.
    fn complete_fragment(&mut self, opcode: OpCode) -> Result<Option<Message>> {
        self.fragment_opcode = None;
        let compressed_data = self.fragment_buf.split().freeze();

        let data = if self.fragment_compressed {
            self.fragment_compressed = false;
            self.decoder
                .decompress(&compressed_data, self.max_message_size)?
        } else {
            compressed_data
        };

        match opcode {
            OpCode::Text => {
                if !validate_utf8(&data) {
                    return Err(Error::InvalidUtf8);
                }
                Ok(Some(Message::Text(data)))
            }
            OpCode::Binary => Ok(Some(Message::Binary(data))),
            _ => Err(Error::Protocol("invalid fragment opcode")),
        }
    }

    /// Validates a Close frame's payload and surfaces it as
    /// `Message::Close`. This half tracks no closing-handshake state.
    ///
    /// # Errors
    /// - `Error::InvalidCloseCode` for a rejected status code (codes in
    ///   3000–4999 are always accepted).
    /// - `Error::InvalidUtf8` for bad reason text.
    /// - `Error::Protocol` for a 1-byte payload.
    fn handle_close(&mut self, frame: Frame) -> Result<Option<Message>> {
        let reason = if frame.payload.len() >= 2 {
            let code = u16::from_be_bytes([frame.payload[0], frame.payload[1]]);

            if !CloseReason::is_valid_code(code) && !(3000..=4999).contains(&code) {
                return Err(Error::InvalidCloseCode(code));
            }

            let reason_text = if frame.payload.len() > 2 {
                let text = &frame.payload[2..];
                if !validate_utf8(text) {
                    return Err(Error::InvalidUtf8);
                }
                String::from_utf8_lossy(text).into_owned()
            } else {
                String::new()
            };

            Some(CloseReason::new(code, reason_text))
        } else if frame.payload.is_empty() {
            None
        } else {
            return Err(Error::Protocol("invalid close frame payload"));
        };

        Ok(Some(Message::Close(reason)))
    }
}
1098
/// Send half of a compressed WebSocket connection.
///
/// Encodes outgoing messages, deflating data messages when the encoder
/// produces compressed output. It tracks no closing-handshake state.
#[cfg(feature = "permessage-deflate")]
pub struct CompressedWriterProtocol {
    /// Side of the connection this half belongs to; clients mask frames.
    role: Role,
    /// Deflate context for outgoing data messages.
    encoder: crate::deflate::DeflateEncoder,
}
1109
#[cfg(feature = "permessage-deflate")]
impl CompressedWriterProtocol {
    /// Largest payload, in bytes, that a control frame may carry
    /// (RFC 6455 §5.5).
    const MAX_CONTROL_PAYLOAD: usize = 125;

    /// Creates a server-side writer using the server's negotiated window,
    /// context-takeover, level and threshold settings.
    pub fn server(config: &DeflateConfig) -> Self {
        Self {
            role: Role::Server,
            encoder: crate::deflate::DeflateEncoder::new(
                config.server_max_window_bits,
                config.server_no_context_takeover,
                config.compression_level,
                config.compression_threshold,
            ),
        }
    }

    /// Creates a client-side writer using the client's negotiated window,
    /// context-takeover, level and threshold settings.
    pub fn client(config: &DeflateConfig) -> Self {
        Self {
            role: Role::Client,
            encoder: crate::deflate::DeflateEncoder::new(
                config.client_max_window_bits,
                config.client_no_context_takeover,
                config.compression_level,
                config.compression_threshold,
            ),
        }
    }

    /// Encodes `msg` as a single final frame appended to `buf`.
    ///
    /// Data messages are sent with RSV1 and a deflated payload when the
    /// encoder returns compressed output. A `None` result (presumably below
    /// the compression threshold) sends the payload as-is. Control frames
    /// are never compressed. A close reason too long for a control frame is
    /// truncated at a UTF-8 character boundary.
    ///
    /// # Errors
    /// Compression errors, and `Error::Protocol` if a Ping or Pong payload
    /// exceeds 125 bytes.
    pub fn encode_message(&mut self, msg: &Message, buf: &mut BytesMut) -> Result<()> {
        let mask = (self.role == Role::Client).then(crate::mask::generate_mask_fast);

        let (opcode, data) = match msg {
            Message::Text(b) => (OpCode::Text, b),
            Message::Binary(b) => (OpCode::Binary, b),
            Message::Ping(b) => {
                Self::check_control_payload(b.len())?;
                encode_frame(buf, OpCode::Ping, b, true, mask);
                return Ok(());
            }
            Message::Pong(b) => {
                Self::check_control_payload(b.len())?;
                encode_frame(buf, OpCode::Pong, b, true, mask);
                return Ok(());
            }
            Message::Close(reason) => {
                let payload = Self::close_payload(reason.as_ref());
                encode_frame(buf, OpCode::Close, &payload, true, mask);
                return Ok(());
            }
        };

        if let Some(compressed) = self.encoder.compress(data)? {
            encode_frame_with_rsv(buf, opcode, &compressed, true, mask, true);
        } else {
            encode_frame(buf, opcode, data, true, mask);
        }

        Ok(())
    }

    /// Encodes an uncompressed Pong echoing `ping_data`, masked for clients.
    pub fn encode_pong(&mut self, ping_data: &[u8], buf: &mut BytesMut) {
        let mask = (self.role == Role::Client).then(crate::mask::generate_mask_fast);
        encode_frame(buf, OpCode::Pong, ping_data, true, mask);
    }

    /// Encodes a Close frame with an empty payload. This half never sees the
    /// peer's Close, so it cannot echo its status code.
    pub fn encode_close_response(&mut self, buf: &mut BytesMut) {
        let mask = (self.role == Role::Client).then(crate::mask::generate_mask_fast);
        encode_frame(buf, OpCode::Close, &[], true, mask);
    }

    /// Rejects control payloads that do not fit in a control frame.
    fn check_control_payload(len: usize) -> Result<()> {
        if len > Self::MAX_CONTROL_PAYLOAD {
            Err(Error::Protocol("control frame payload too large"))
        } else {
            Ok(())
        }
    }

    /// Serializes a close reason as `code (big-endian) ++ reason text`.
    ///
    /// `None` yields an empty payload. The text is truncated on a UTF-8
    /// character boundary so the payload fits in a control frame.
    fn close_payload(reason: Option<&CloseReason>) -> Bytes {
        let reason = match reason {
            Some(r) => r,
            None => return Bytes::new(),
        };

        let text: &str = &reason.reason;
        // Two bytes of the control payload are taken by the status code.
        let mut end = text.len().min(Self::MAX_CONTROL_PAYLOAD - 2);
        while !text.is_char_boundary(end) {
            end -= 1;
        }

        let mut payload = BytesMut::with_capacity(2 + end);
        payload.extend_from_slice(&reason.code.to_be_bytes());
        payload.extend_from_slice(&text.as_bytes()[..end]);
        payload.freeze()
    }
}
1203
#[cfg(test)]
mod tests {
    use super::*;

    /// Appends one masked frame, as a client would send it, to `buf`.
    ///
    /// `head` is the first header byte (FIN bit and opcode). The payload
    /// must fit the 7-bit length field.
    fn push_masked_frame(buf: &mut BytesMut, head: u8, key: [u8; 4], payload: &[u8]) {
        assert!(payload.len() <= 125, "helper only supports short frames");
        buf.extend_from_slice(&[head, 0x80 | payload.len() as u8]);
        buf.extend_from_slice(&key);
        let mut masked = payload.to_vec();
        crate::simd::apply_mask(&mut masked[..], key);
        buf.extend_from_slice(&masked);
    }

    /// A server-side protocol with generous limits.
    fn server_protocol() -> Protocol {
        Protocol::new(Role::Server, 1024 * 1024, 64 * 1024 * 1024)
    }

    #[test]
    fn test_message_text() {
        let mut protocol = server_protocol();
        let mut buf = BytesMut::new();
        push_masked_frame(&mut buf, 0x81, [0x37, 0xfa, 0x21, 0x3d], b"Hello");

        let messages = protocol.process(&mut buf).unwrap();
        assert_eq!(messages.len(), 1);

        match &messages[0] {
            Message::Text(s) => assert_eq!(s, "Hello"),
            _ => panic!("Expected text message"),
        }
    }

    #[test]
    fn test_fragmented_message() {
        let mut protocol = server_protocol();
        let mut buf = BytesMut::new();

        // First fragment: text opcode, FIN clear.
        push_masked_frame(&mut buf, 0x01, [0x00; 4], b"Hel");
        let messages = protocol.process(&mut buf).unwrap();
        assert!(messages.is_empty());

        // Final continuation fragment: FIN set, opcode 0.
        push_masked_frame(&mut buf, 0x80, [0x00; 4], b"lo");
        let messages = protocol.process(&mut buf).unwrap();
        assert_eq!(messages.len(), 1);

        match &messages[0] {
            Message::Text(s) => assert_eq!(s, "Hello"),
            _ => panic!("Expected text message"),
        }
    }

    #[test]
    fn test_encode_message() {
        let mut buf = BytesMut::new();
        server_protocol()
            .encode_message(&Message::text("test"), &mut buf)
            .unwrap();

        // Server frames are unmasked: FIN+Text, 4-byte length, raw payload.
        assert_eq!(buf[0], 0x81);
        assert_eq!(buf[1], 0x04);
        assert_eq!(&buf[2..], b"test");
    }
}