sockudo_ws/
protocol.rs

1//! WebSocket protocol implementation
2//!
3//! This module handles the WebSocket protocol state machine, including:
4//! - Message fragmentation and reassembly
5//! - Control frame handling (ping/pong/close)
6//! - State transitions
7
8use bytes::{Bytes, BytesMut};
9
10use crate::error::{CloseReason, Error, Result};
11#[cfg(feature = "permessage-deflate")]
12use crate::frame::encode_frame_with_rsv;
13use crate::frame::{Frame, FrameParser, OpCode, encode_frame};
14use crate::utf8::{validate_utf8, validate_utf8_incomplete};
15
16#[cfg(feature = "permessage-deflate")]
17use crate::deflate::{DeflateConfig, DeflateContext};
18
/// Which side of a WebSocket connection this endpoint plays.
///
/// The role fixes the masking rule for frames on the wire: clients mask the
/// frames they send, servers send them unmasked.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Role {
    /// Connecting side; outgoing frames are masked.
    Client,
    /// Accepting side; outgoing frames are not masked.
    Server,
}
27
28/// WebSocket message (complete, possibly assembled from fragments)
29///
30/// Text messages use `Bytes` internally for zero-copy efficiency.
31/// The payload is UTF-8 validated during parsing, so `as_text()` is safe.
32#[derive(Debug, Clone)]
33pub enum Message {
34    /// Text message (UTF-8 validated, stored as Bytes for zero-copy)
35    Text(Bytes),
36    /// Binary message
37    Binary(Bytes),
38    /// Ping message
39    Ping(Bytes),
40    /// Pong message
41    Pong(Bytes),
42    /// Close message
43    Close(Option<CloseReason>),
44}
45
46impl Message {
47    /// Create a text message from a string
48    #[inline]
49    pub fn text(s: impl Into<String>) -> Self {
50        Message::Text(Bytes::from(s.into()))
51    }
52
53    /// Create a binary message
54    #[inline]
55    pub fn binary(data: impl Into<Bytes>) -> Self {
56        Message::Binary(data.into())
57    }
58
59    /// Create a ping message
60    #[inline]
61    pub fn ping(data: impl Into<Bytes>) -> Self {
62        Message::Ping(data.into())
63    }
64
65    /// Create a pong message
66    #[inline]
67    pub fn pong(data: impl Into<Bytes>) -> Self {
68        Message::Pong(data.into())
69    }
70
71    /// Check if this is a close message
72    #[inline]
73    pub fn is_close(&self) -> bool {
74        matches!(self, Message::Close(_))
75    }
76
77    /// Check if this is a text message
78    #[inline]
79    pub fn is_text(&self) -> bool {
80        matches!(self, Message::Text(_))
81    }
82
83    /// Check if this is a binary message
84    #[inline]
85    pub fn is_binary(&self) -> bool {
86        matches!(self, Message::Binary(_))
87    }
88
89    /// Check if this is a ping message
90    #[inline]
91    pub fn is_ping(&self) -> bool {
92        matches!(self, Message::Ping(_))
93    }
94
95    /// Check if this is a pong message
96    #[inline]
97    pub fn is_pong(&self) -> bool {
98        matches!(self, Message::Pong(_))
99    }
100
101    /// Check if this is a control message
102    #[inline]
103    pub fn is_control(&self) -> bool {
104        matches!(
105            self,
106            Message::Ping(_) | Message::Pong(_) | Message::Close(_)
107        )
108    }
109
110    /// Get message as text (returns None for non-text messages)
111    ///
112    /// This is zero-copy - it returns a reference to the underlying bytes.
113    /// The text is guaranteed to be valid UTF-8 as it was validated during parsing.
114    #[inline]
115    pub fn as_text(&self) -> Option<&str> {
116        match self {
117            Message::Text(b) => {
118                // SAFETY: Text messages are UTF-8 validated during parsing
119                Some(unsafe { std::str::from_utf8_unchecked(b) })
120            }
121            _ => None,
122        }
123    }
124
125    /// Get message as bytes
126    #[inline]
127    pub fn as_bytes(&self) -> &[u8] {
128        match self {
129            Message::Text(b) => b,
130            Message::Binary(b) => b,
131            Message::Ping(b) => b,
132            Message::Pong(b) => b,
133            Message::Close(_) => &[],
134        }
135    }
136
137    /// Convert to text message (allocates a String)
138    ///
139    /// Returns None for non-text messages.
140    pub fn into_text(self) -> Option<String> {
141        match self {
142            Message::Text(b) => {
143                // SAFETY: Text messages are UTF-8 validated during parsing
144                Some(unsafe { String::from_utf8_unchecked(b.to_vec()) })
145            }
146            _ => None,
147        }
148    }
149
150    /// Get the underlying Bytes for text messages (zero-copy)
151    ///
152    /// Returns None for non-text messages.
153    #[inline]
154    pub fn text_bytes(&self) -> Option<&Bytes> {
155        match self {
156            Message::Text(b) => Some(b),
157            _ => None,
158        }
159    }
160
161    /// Convert to binary data
162    pub fn into_bytes(self) -> Bytes {
163        match self {
164            Message::Text(b) => b,
165            Message::Binary(b) => b,
166            Message::Ping(b) => b,
167            Message::Pong(b) => b,
168            Message::Close(_) => Bytes::new(),
169        }
170    }
171}
172
173impl From<String> for Message {
174    fn from(s: String) -> Self {
175        Message::Text(Bytes::from(s))
176    }
177}
178
179impl From<&str> for Message {
180    fn from(s: &str) -> Self {
181        Message::Text(Bytes::copy_from_slice(s.as_bytes()))
182    }
183}
184
185impl From<Vec<u8>> for Message {
186    fn from(v: Vec<u8>) -> Self {
187        Message::Binary(Bytes::from(v))
188    }
189}
190
191impl From<Bytes> for Message {
192    fn from(b: Bytes) -> Self {
193        Message::Binary(b)
194    }
195}
196
197impl From<&[u8]> for Message {
198    fn from(b: &[u8]) -> Self {
199        Message::Binary(Bytes::copy_from_slice(b))
200    }
201}
202
/// Where the connection stands in the closing handshake.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Normal operation: no Close frame has been sent or received yet.
    Open,
    /// We sent a Close frame and are waiting for the peer's.
    CloseSent,
    /// The peer sent a Close frame and our reply has not been encoded yet.
    CloseReceived,
    /// Close frames have gone both ways; the handshake is finished.
    Closed,
}
215
216/// WebSocket protocol handler
217///
218/// Handles frame parsing, message assembly, and control frame processing.
219pub struct Protocol {
220    /// Endpoint role
221    pub(crate) role: Role,
222    /// Current state
223    state: State,
224    /// Frame parser
225    pub(crate) parser: FrameParser,
226    /// Fragment buffer for message reassembly
227    pub(crate) fragment_buf: BytesMut,
228    /// Opcode of current fragmented message
229    pub(crate) fragment_opcode: Option<OpCode>,
230    /// Maximum message size
231    pub(crate) max_message_size: usize,
232    /// Pending close reason (if we received a close frame)
233    pending_close: Option<CloseReason>,
234}
235
236impl Protocol {
237    /// Create a new protocol handler
238    pub fn new(role: Role, max_frame_size: usize, max_message_size: usize) -> Self {
239        let expect_masked = role == Role::Server;
240
241        Self {
242            role,
243            state: State::Open,
244            parser: FrameParser::new(max_frame_size, expect_masked),
245            fragment_buf: BytesMut::new(),
246            fragment_opcode: None,
247            max_message_size,
248            pending_close: None,
249        }
250    }
251
252    /// Check if connection is closed
253    #[inline]
254    pub fn is_closed(&self) -> bool {
255        self.state == State::Closed
256    }
257
258    /// Check if we're in the closing handshake
259    #[inline]
260    pub fn is_closing(&self) -> bool {
261        matches!(self.state, State::CloseSent | State::CloseReceived)
262    }
263
264    /// Process incoming data and return complete messages
265    ///
266    /// This may return multiple messages if the buffer contains multiple frames.
267    /// Control frames are processed inline and may be interspersed with data frames.
268    #[inline]
269    pub fn process(&mut self, buf: &mut BytesMut) -> Result<Vec<Message>> {
270        let mut messages = Vec::new();
271        self.process_into(buf, &mut messages)?;
272        Ok(messages)
273    }
274
275    /// Process incoming data into a reusable message buffer (zero-allocation hot path)
276    ///
277    /// This variant allows reusing a Vec<Message> across calls to avoid allocations.
278    #[inline]
279    pub fn process_into(&mut self, buf: &mut BytesMut, messages: &mut Vec<Message>) -> Result<()> {
280        messages.clear();
281
282        while !buf.is_empty() {
283            match self.parser.parse(buf)? {
284                Some(frame) => {
285                    if let Some(msg) = self.handle_frame(frame)? {
286                        messages.push(msg);
287                    }
288                }
289                None => break,
290            }
291        }
292
293        Ok(())
294    }
295
296    /// Handle a single parsed frame
297    fn handle_frame(&mut self, frame: Frame) -> Result<Option<Message>> {
298        match frame.header.opcode {
299            OpCode::Continuation => self.handle_continuation(frame),
300            OpCode::Text => self.handle_text(frame),
301            OpCode::Binary => self.handle_binary(frame),
302            OpCode::Close => self.handle_close(frame),
303            OpCode::Ping => self.handle_ping(frame),
304            OpCode::Pong => self.handle_pong(frame),
305        }
306    }
307
308    /// Handle text frame
309    fn handle_text(&mut self, frame: Frame) -> Result<Option<Message>> {
310        if self.fragment_opcode.is_some() {
311            return Err(Error::Protocol("expected continuation frame"));
312        }
313
314        if frame.header.fin {
315            // Complete message in one frame (fast path)
316            if !validate_utf8(&frame.payload) {
317                return Err(Error::InvalidUtf8);
318            }
319            // Zero-copy: just return the Bytes directly (already UTF-8 validated)
320            Ok(Some(Message::Text(frame.payload)))
321        } else {
322            // Start of fragmented message
323            self.start_fragment(OpCode::Text, frame.payload)?;
324            Ok(None)
325        }
326    }
327
328    /// Handle binary frame
329    fn handle_binary(&mut self, frame: Frame) -> Result<Option<Message>> {
330        if self.fragment_opcode.is_some() {
331            return Err(Error::Protocol("expected continuation frame"));
332        }
333
334        if frame.header.fin {
335            // Complete message in one frame (fast path)
336            Ok(Some(Message::Binary(frame.payload)))
337        } else {
338            // Start of fragmented message
339            self.start_fragment(OpCode::Binary, frame.payload)?;
340            Ok(None)
341        }
342    }
343
344    /// Handle continuation frame
345    fn handle_continuation(&mut self, frame: Frame) -> Result<Option<Message>> {
346        let opcode = self
347            .fragment_opcode
348            .ok_or(Error::Protocol("unexpected continuation frame"))?;
349
350        // Check total size
351        let new_size = self.fragment_buf.len() + frame.payload.len();
352        if new_size > self.max_message_size {
353            return Err(Error::MessageTooLarge);
354        }
355
356        self.fragment_buf.extend_from_slice(&frame.payload);
357
358        if frame.header.fin {
359            // Complete the fragmented message
360            self.complete_fragment(opcode)
361        } else {
362            // Validate partial UTF-8 for text messages
363            if opcode == OpCode::Text {
364                let (valid, _incomplete) = validate_utf8_incomplete(&self.fragment_buf);
365                if !valid {
366                    return Err(Error::InvalidUtf8);
367                }
368            }
369            Ok(None)
370        }
371    }
372
373    /// Start a fragmented message
374    pub(crate) fn start_fragment(&mut self, opcode: OpCode, payload: Bytes) -> Result<()> {
375        if payload.len() > self.max_message_size {
376            return Err(Error::MessageTooLarge);
377        }
378
379        self.fragment_opcode = Some(opcode);
380        self.fragment_buf.clear();
381        self.fragment_buf.extend_from_slice(&payload);
382
383        // Validate partial UTF-8 for text messages
384        if opcode == OpCode::Text {
385            let (valid, _incomplete) = validate_utf8_incomplete(&self.fragment_buf);
386            if !valid {
387                return Err(Error::InvalidUtf8);
388            }
389        }
390
391        Ok(())
392    }
393
394    /// Complete a fragmented message
395    fn complete_fragment(&mut self, opcode: OpCode) -> Result<Option<Message>> {
396        self.fragment_opcode = None;
397        let data = self.fragment_buf.split().freeze();
398
399        match opcode {
400            OpCode::Text => {
401                if !validate_utf8(&data) {
402                    return Err(Error::InvalidUtf8);
403                }
404                // Zero-copy: just return the Bytes directly (already UTF-8 validated)
405                Ok(Some(Message::Text(data)))
406            }
407            OpCode::Binary => Ok(Some(Message::Binary(data))),
408            _ => Err(Error::Protocol("invalid fragment opcode")),
409        }
410    }
411
412    /// Handle close frame
413    pub(crate) fn handle_close(&mut self, frame: Frame) -> Result<Option<Message>> {
414        let reason = if frame.payload.len() >= 2 {
415            let code = u16::from_be_bytes([frame.payload[0], frame.payload[1]]);
416
417            // Validate close code
418            if !CloseReason::is_valid_code(code) && !(3000..=4999).contains(&code) {
419                return Err(Error::InvalidCloseCode(code));
420            }
421
422            let reason_text = if frame.payload.len() > 2 {
423                let text = &frame.payload[2..];
424                if !validate_utf8(text) {
425                    return Err(Error::InvalidUtf8);
426                }
427                String::from_utf8_lossy(text).into_owned()
428            } else {
429                String::new()
430            };
431
432            Some(CloseReason::new(code, reason_text))
433        } else if frame.payload.is_empty() {
434            None
435        } else {
436            // Close frame with 1 byte payload is invalid
437            return Err(Error::Protocol("invalid close frame payload"));
438        };
439
440        match self.state {
441            State::Open => {
442                self.state = State::CloseReceived;
443                self.pending_close = reason.clone();
444            }
445            State::CloseSent => {
446                self.state = State::Closed;
447            }
448            _ => {}
449        }
450
451        Ok(Some(Message::Close(reason)))
452    }
453
454    /// Handle ping frame
455    pub(crate) fn handle_ping(&mut self, frame: Frame) -> Result<Option<Message>> {
456        Ok(Some(Message::Ping(frame.payload)))
457    }
458
459    /// Handle pong frame
460    pub(crate) fn handle_pong(&mut self, frame: Frame) -> Result<Option<Message>> {
461        Ok(Some(Message::Pong(frame.payload)))
462    }
463
464    /// Encode a message for sending
465    pub fn encode_message(&mut self, msg: &Message, buf: &mut BytesMut) -> Result<()> {
466        let mask = if self.role == Role::Client {
467            Some(crate::mask::generate_mask_fast())
468        } else {
469            None
470        };
471
472        match msg {
473            Message::Text(b) => {
474                encode_frame(buf, OpCode::Text, b, true, mask);
475            }
476            Message::Binary(b) => {
477                encode_frame(buf, OpCode::Binary, b, true, mask);
478            }
479            Message::Ping(b) => {
480                encode_frame(buf, OpCode::Ping, b, true, mask);
481            }
482            Message::Pong(b) => {
483                encode_frame(buf, OpCode::Pong, b, true, mask);
484            }
485            Message::Close(reason) => {
486                if self.state == State::Open {
487                    self.state = State::CloseSent;
488                }
489
490                let payload = if let Some(r) = reason {
491                    let mut p = BytesMut::with_capacity(2 + r.reason.len());
492                    p.extend_from_slice(&r.code.to_be_bytes());
493                    p.extend_from_slice(r.reason.as_bytes());
494                    p.freeze()
495                } else {
496                    Bytes::new()
497                };
498
499                encode_frame(buf, OpCode::Close, &payload, true, mask);
500            }
501        }
502
503        Ok(())
504    }
505
506    /// Encode a pong response for a ping
507    pub fn encode_pong(&mut self, ping_data: &[u8], buf: &mut BytesMut) {
508        let mask = if self.role == Role::Client {
509            Some(crate::mask::generate_mask_fast())
510        } else {
511            None
512        };
513        encode_frame(buf, OpCode::Pong, ping_data, true, mask);
514    }
515
516    /// Encode a close response
517    pub fn encode_close_response(&mut self, buf: &mut BytesMut) {
518        let mask = if self.role == Role::Client {
519            Some(crate::mask::generate_mask_fast())
520        } else {
521            None
522        };
523
524        let payload = if let Some(ref reason) = self.pending_close {
525            let mut p = BytesMut::with_capacity(2 + reason.reason.len());
526            p.extend_from_slice(&reason.code.to_be_bytes());
527            p.extend_from_slice(reason.reason.as_bytes());
528            p.freeze()
529        } else {
530            Bytes::new()
531        };
532
533        encode_frame(buf, OpCode::Close, &payload, true, mask);
534
535        if self.state == State::CloseReceived {
536            self.state = State::Closed;
537        }
538    }
539
540    /// Enable compression support in the frame parser
541    #[cfg(feature = "permessage-deflate")]
542    pub fn enable_compression(&mut self) {
543        self.parser.set_compression(true);
544    }
545}
546
/// WebSocket protocol handler with permessage-deflate support (RFC 7692).
///
/// Wraps a plain [`Protocol`] for framing and the closing handshake. A
/// [`DeflateContext`] compresses outgoing data messages and inflates incoming
/// ones.
#[cfg(feature = "permessage-deflate")]
pub struct CompressedProtocol {
    /// Underlying uncompressed protocol engine.
    inner: Protocol,
    /// Compression and decompression state.
    deflate: DeflateContext,
    /// True when the first frame of the in-flight fragmented message had RSV1
    /// set, so the reassembled payload must be inflated.
    deflate_pending: bool,
    /// Scratch buffer for decompressed fragment data.
    /// NOTE(review): not read anywhere in the visible code; confirm it is still needed.
    decompress_buf: BytesMut,
}
559
#[cfg(feature = "permessage-deflate")]
impl CompressedProtocol {
    /// Create a compressed protocol handler for the server role.
    ///
    /// The inner parser is switched into compression mode via
    /// `Protocol::enable_compression`.
    pub fn server(max_frame_size: usize, max_message_size: usize, config: DeflateConfig) -> Self {
        let mut inner = Protocol::new(Role::Server, max_frame_size, max_message_size);
        inner.enable_compression();

        Self {
            inner,
            deflate: DeflateContext::server(config),
            fragment_compressed: false,
            decompress_buf: BytesMut::new(),
        }
    }

    /// Create a compressed protocol handler for the client role.
    ///
    /// The inner parser is switched into compression mode via
    /// `Protocol::enable_compression`.
    pub fn client(max_frame_size: usize, max_message_size: usize, config: DeflateConfig) -> Self {
        let mut inner = Protocol::new(Role::Client, max_frame_size, max_message_size);
        inner.enable_compression();

        Self {
            inner,
            deflate: DeflateContext::client(config),
            fragment_compressed: false,
            decompress_buf: BytesMut::new(),
        }
    }

    /// Returns `true` once the closing handshake has completed.
    #[inline]
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    /// Returns `true` while the closing handshake is in progress.
    #[inline]
    pub fn is_closing(&self) -> bool {
        self.inner.is_closing()
    }

    /// Process incoming data and return complete (decompressed) messages.
    ///
    /// # Errors
    /// Returns the first parse, protocol or decompression error; messages
    /// decoded before the error are discarded.
    pub fn process(&mut self, buf: &mut BytesMut) -> Result<Vec<Message>> {
        let mut messages = Vec::new();
        self.process_into(buf, &mut messages)?;
        Ok(messages)
    }

    /// Process incoming data into a reusable message buffer.
    ///
    /// `messages` is cleared first. Parsing stops when `buf` is empty or the
    /// parser yields no frame. Messages decoded before an error remain in
    /// `messages`.
    #[inline]
    pub fn process_into(&mut self, buf: &mut BytesMut, messages: &mut Vec<Message>) -> Result<()> {
        messages.clear();

        while !buf.is_empty() {
            match self.inner.parser.parse(buf)? {
                Some(frame) => {
                    if let Some(msg) = self.handle_frame(frame)? {
                        messages.push(msg);
                    }
                }
                None => break,
            }
        }

        Ok(())
    }

    /// Dispatch a parsed frame, routing data frames through decompression.
    fn handle_frame(&mut self, frame: Frame) -> Result<Option<Message>> {
        let is_compressed = frame.header.rsv1;

        match frame.header.opcode {
            OpCode::Continuation => self.handle_continuation(frame),
            OpCode::Text => self.handle_text(frame, is_compressed),
            OpCode::Binary => self.handle_binary(frame, is_compressed),
            // RFC 7692 section 6.1: control frames are never compressed.
            OpCode::Close | OpCode::Ping | OpCode::Pong if is_compressed => {
                Err(Error::Protocol("RSV1 set on control frame"))
            }
            OpCode::Close => self.inner.handle_close(frame),
            OpCode::Ping => self.inner.handle_ping(frame),
            OpCode::Pong => self.inner.handle_pong(frame),
        }
    }

    /// Payload of an unfragmented data frame: inflated when RSV1 was set,
    /// otherwise checked against `max_message_size` as-is.
    fn single_frame_payload(&mut self, payload: Bytes, compressed: bool) -> Result<Bytes> {
        if compressed {
            Ok(self
                .deflate
                .decompress(&payload, self.inner.max_message_size)?)
        } else if payload.len() > self.inner.max_message_size {
            Err(Error::MessageTooLarge)
        } else {
            Ok(payload)
        }
    }

    /// Handle a text frame, decompressing it if RSV1 was set.
    fn handle_text(&mut self, frame: Frame, compressed: bool) -> Result<Option<Message>> {
        if self.inner.fragment_opcode.is_some() {
            return Err(Error::Protocol("expected continuation frame"));
        }

        if !frame.header.fin {
            self.fragment_compressed = compressed;
            if compressed {
                // Compressed bytes are not UTF-8. Buffer them as binary to skip
                // the partial UTF-8 check, then restore the Text opcode so the
                // inflated message is validated on completion.
                self.inner.start_fragment(OpCode::Binary, frame.payload)?;
                self.inner.fragment_opcode = Some(OpCode::Text);
            } else {
                self.inner.start_fragment(OpCode::Text, frame.payload)?;
            }
            return Ok(None);
        }

        let payload = self.single_frame_payload(frame.payload, compressed)?;
        if !validate_utf8(&payload) {
            return Err(Error::InvalidUtf8);
        }
        Ok(Some(Message::Text(payload)))
    }

    /// Handle a binary frame, decompressing it if RSV1 was set.
    fn handle_binary(&mut self, frame: Frame, compressed: bool) -> Result<Option<Message>> {
        if self.inner.fragment_opcode.is_some() {
            return Err(Error::Protocol("expected continuation frame"));
        }

        if !frame.header.fin {
            self.fragment_compressed = compressed;
            self.inner.start_fragment(OpCode::Binary, frame.payload)?;
            return Ok(None);
        }

        let payload = self.single_frame_payload(frame.payload, compressed)?;
        Ok(Some(Message::Binary(payload)))
    }

    /// Handle a continuation frame of the in-flight fragmented message.
    fn handle_continuation(&mut self, frame: Frame) -> Result<Option<Message>> {
        let opcode = self
            .inner
            .fragment_opcode
            .ok_or(Error::Protocol("unexpected continuation frame"))?;

        // RFC 7692 section 6.1: only the first fragment may carry RSV1.
        if frame.header.rsv1 {
            return Err(Error::Protocol("RSV1 set on continuation frame"));
        }

        // Limit on the buffered (possibly still compressed) size.
        let new_size = self
            .inner
            .fragment_buf
            .len()
            .saturating_add(frame.payload.len());
        if new_size > self.inner.max_message_size {
            return Err(Error::MessageTooLarge);
        }

        self.inner.fragment_buf.extend_from_slice(&frame.payload);

        if frame.header.fin {
            return self.complete_fragment(opcode);
        }

        // Uncompressed text can be checked incrementally, as `Protocol` does.
        // A trailing incomplete code point is allowed.
        if opcode == OpCode::Text && !self.fragment_compressed {
            let (valid, _incomplete) = validate_utf8_incomplete(&self.inner.fragment_buf);
            if !valid {
                return Err(Error::InvalidUtf8);
            }
        }
        Ok(None)
    }

    /// Finish a fragmented message, inflating it if its first frame had RSV1 set.
    fn complete_fragment(&mut self, opcode: OpCode) -> Result<Option<Message>> {
        self.inner.fragment_opcode = None;
        let raw = self.inner.fragment_buf.split().freeze();

        // Reset the flag for the next message whether or not it was set.
        let data = if std::mem::take(&mut self.fragment_compressed) {
            self.deflate
                .decompress(&raw, self.inner.max_message_size)?
        } else {
            raw
        };

        match opcode {
            OpCode::Text => {
                if !validate_utf8(&data) {
                    return Err(Error::InvalidUtf8);
                }
                Ok(Some(Message::Text(data)))
            }
            OpCode::Binary => Ok(Some(Message::Binary(data))),
            _ => Err(Error::Protocol("invalid fragment opcode")),
        }
    }

    /// Encode a message, compressing Text/Binary payloads when the deflate
    /// context chooses to.
    ///
    /// Compressed data frames are written with RSV1 set. Control frames are
    /// never compressed. They are delegated to the inner `Protocol`, which
    /// also tracks the closing handshake for Close.
    pub fn encode_message(&mut self, msg: &Message, buf: &mut BytesMut) -> Result<()> {
        let (opcode, payload) = match msg {
            Message::Text(b) => (OpCode::Text, b),
            Message::Binary(b) => (OpCode::Binary, b),
            Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {
                return self.inner.encode_message(msg, buf);
            }
        };

        let mask = if self.inner.role == Role::Client {
            Some(crate::mask::generate_mask_fast())
        } else {
            None
        };

        match self.deflate.compress(payload)? {
            Some(compressed) => {
                encode_frame_with_rsv(buf, opcode, &compressed, true, mask, true);
            }
            None => {
                encode_frame(buf, opcode, payload, true, mask);
            }
        }

        Ok(())
    }

    /// Encode a pong response for a ping.
    pub fn encode_pong(&mut self, ping_data: &[u8], buf: &mut BytesMut) {
        self.inner.encode_pong(ping_data, buf);
    }

    /// Encode the reply to a Close frame received from the peer.
    pub fn encode_close_response(&mut self, buf: &mut BytesMut) {
        self.inner.encode_close_response(buf);
    }

    /// Split into independent reader and writer halves.
    ///
    /// The reader keeps the fragment-reassembly state and the deflate decoder.
    /// The writer keeps the deflate encoder. The reader gets a new
    /// `FrameParser` built from `max_frame_size`, so any partially parsed frame
    /// state in the current parser is not carried over. Closing-handshake
    /// state is not carried into either half.
    pub fn split(
        self,
        max_frame_size: usize,
        max_message_size: usize,
    ) -> (CompressedReaderProtocol, CompressedWriterProtocol) {
        let role = self.inner.role;

        let reader = CompressedReaderProtocol {
            role,
            parser: FrameParser::new(max_frame_size, role == Role::Server),
            fragment_buf: self.inner.fragment_buf,
            fragment_opcode: self.inner.fragment_opcode,
            max_message_size,
            decoder: self.deflate.decoder,
            fragment_compressed: self.fragment_compressed,
        };

        let writer = CompressedWriterProtocol {
            role,
            encoder: self.deflate.encoder,
        };

        (reader, writer)
    }
}
868
/// Read half of a split [`CompressedProtocol`].
///
/// Owns the frame parser, the fragment-reassembly state and the deflate
/// decoder needed to turn incoming bytes into (decompressed) messages.
#[cfg(feature = "permessage-deflate")]
pub struct CompressedReaderProtocol {
    /// Whether this endpoint acts as client or server.
    role: Role,
    /// Decoder for incoming frames.
    parser: FrameParser,
    /// Payload accumulated so far for the fragmented message in flight.
    fragment_buf: BytesMut,
    /// Opcode of the fragmented message in flight; `None` between messages.
    fragment_opcode: Option<OpCode>,
    /// Size limit, in bytes, for a message.
    max_message_size: usize,
    /// Inflater for compressed incoming messages.
    decoder: crate::deflate::DeflateDecoder,
    /// True when the in-flight fragmented message started with RSV1 set.
    fragment_compressed: bool,
}
889
#[cfg(feature = "permessage-deflate")]
impl CompressedReaderProtocol {
    /// Create a new reader protocol for server role
    ///
    /// The decoder is configured from the `client_*` negotiated parameters,
    /// since a server inflates what the client compressed.
    pub fn server(max_frame_size: usize, max_message_size: usize, config: &DeflateConfig) -> Self {
        Self {
            role: Role::Server,
            parser: FrameParser::new(max_frame_size, true),
            fragment_buf: BytesMut::new(),
            fragment_opcode: None,
            max_message_size,
            decoder: crate::deflate::DeflateDecoder::new(
                config.client_max_window_bits,
                config.client_no_context_takeover,
            ),
            fragment_compressed: false,
        }
    }

    /// Create a new reader protocol for client role
    ///
    /// The decoder is configured from the `server_*` negotiated parameters,
    /// since a client inflates what the server compressed.
    pub fn client(max_frame_size: usize, max_message_size: usize, config: &DeflateConfig) -> Self {
        Self {
            role: Role::Client,
            parser: FrameParser::new(max_frame_size, false),
            fragment_buf: BytesMut::new(),
            fragment_opcode: None,
            max_message_size,
            decoder: crate::deflate::DeflateDecoder::new(
                config.server_max_window_bits,
                config.server_no_context_takeover,
            ),
            fragment_compressed: false,
        }
    }

    /// Process incoming data and return complete messages
    ///
    /// Convenience wrapper around [`Self::process_into`] that allocates a new
    /// `Vec` for the result.
    ///
    /// # Errors
    /// Returns the first parse, protocol, size, UTF-8 or decompression error
    /// encountered.
    pub fn process(&mut self, buf: &mut BytesMut) -> Result<Vec<Message>> {
        let mut messages = Vec::new();
        self.process_into(buf, &mut messages)?;
        Ok(messages)
    }

    /// Process incoming data into a reusable message buffer
    ///
    /// `messages` is cleared first. Complete frames are consumed from `buf`;
    /// a trailing partial frame is left in `buf` for the next call.
    ///
    /// # Errors
    /// Returns the first parse, protocol, size, UTF-8 or decompression error
    /// encountered.
    pub fn process_into(&mut self, buf: &mut BytesMut, messages: &mut Vec<Message>) -> Result<()> {
        messages.clear();

        // Compression was negotiated, so the parser must accept RSV1 on data
        // frames. Re-applied on every call (presumably idempotent).
        self.parser.set_compression(true);

        while !buf.is_empty() {
            match self.parser.parse(buf)? {
                Some(frame) => {
                    if let Some(msg) = self.handle_frame(frame)? {
                        messages.push(msg);
                    }
                }
                None => break,
            }
        }

        Ok(())
    }

    /// Handle a parsed frame with decompression support
    fn handle_frame(&mut self, frame: Frame) -> Result<Option<Message>> {
        let is_compressed = frame.header.rsv1;

        // RFC 7692 §6.1: the "Per-Message Compressed" bit (RSV1) may only be set
        // on the first frame of a data message. Receiving it on a continuation
        // or control frame must fail the connection.
        if is_compressed && !matches!(frame.header.opcode, OpCode::Text | OpCode::Binary) {
            return Err(Error::Protocol("RSV1 set on control or continuation frame"));
        }

        match frame.header.opcode {
            OpCode::Continuation => self.handle_continuation(frame),
            OpCode::Text => self.handle_data_frame(OpCode::Text, frame, is_compressed),
            OpCode::Binary => self.handle_data_frame(OpCode::Binary, frame, is_compressed),
            OpCode::Close => self.handle_close(frame),
            OpCode::Ping => Ok(Some(Message::Ping(frame.payload))),
            OpCode::Pong => Ok(Some(Message::Pong(frame.payload))),
        }
    }

    /// Handle the first (or only) frame of a text or binary message.
    ///
    /// `opcode` is `OpCode::Text` or `OpCode::Binary`; `compressed` is the
    /// frame's RSV1 bit. An unfragmented message is inflated (if compressed),
    /// size-checked and returned; a fragmented one is buffered until its final
    /// continuation frame arrives.
    fn handle_data_frame(
        &mut self,
        opcode: OpCode,
        frame: Frame,
        compressed: bool,
    ) -> Result<Option<Message>> {
        if self.fragment_opcode.is_some() {
            return Err(Error::Protocol("expected continuation frame"));
        }

        if !frame.header.fin {
            self.start_fragment(opcode, frame.payload)?;
            // Recorded only once the fragment has been accepted; the whole
            // reassembled payload is inflated in `complete_fragment`.
            self.fragment_compressed = compressed;
            return Ok(None);
        }

        let payload = if compressed {
            // `max_message_size` is handed to the decoder as its output limit.
            self.decoder
                .decompress(&frame.payload, self.max_message_size)?
        } else {
            // Same limit fragmented messages are held to, even if a single
            // frame is allowed to be larger than a message.
            if frame.payload.len() > self.max_message_size {
                return Err(Error::MessageTooLarge);
            }
            frame.payload
        };

        Self::data_message(opcode, payload).map(Some)
    }

    /// Handle continuation frame
    ///
    /// Appends the payload (still compressed, if the message is) to the
    /// reassembly buffer and completes the message on FIN.
    fn handle_continuation(&mut self, frame: Frame) -> Result<Option<Message>> {
        let opcode = self
            .fragment_opcode
            .ok_or(Error::Protocol("unexpected continuation frame"))?;

        let new_size = self.fragment_buf.len() + frame.payload.len();
        if new_size > self.max_message_size {
            return Err(Error::MessageTooLarge);
        }

        self.fragment_buf.extend_from_slice(&frame.payload);

        if frame.header.fin {
            self.complete_fragment(opcode)
        } else {
            Ok(None)
        }
    }

    /// Start a fragmented message
    ///
    /// # Errors
    /// `Error::MessageTooLarge` if the first fragment alone exceeds
    /// `max_message_size`.
    fn start_fragment(&mut self, opcode: OpCode, payload: Bytes) -> Result<()> {
        if payload.len() > self.max_message_size {
            return Err(Error::MessageTooLarge);
        }

        self.fragment_opcode = Some(opcode);
        self.fragment_buf.clear();
        self.fragment_buf.extend_from_slice(&payload);
        Ok(())
    }

    /// Complete a fragmented message with decompression if needed
    fn complete_fragment(&mut self, opcode: OpCode) -> Result<Option<Message>> {
        self.fragment_opcode = None;
        let assembled = self.fragment_buf.split().freeze();

        // Clear the flag before inflating so a decompression error cannot leak
        // it into the next message.
        let data = if std::mem::take(&mut self.fragment_compressed) {
            self.decoder
                .decompress(&assembled, self.max_message_size)?
        } else {
            assembled
        };

        Self::data_message(opcode, data).map(Some)
    }

    /// Wrap a complete, already-inflated payload as a data message.
    ///
    /// # Errors
    /// `Error::InvalidUtf8` for text that is not valid UTF-8;
    /// `Error::Protocol` for an opcode other than text or binary.
    fn data_message(opcode: OpCode, data: Bytes) -> Result<Message> {
        match opcode {
            OpCode::Text if validate_utf8(&data) => Ok(Message::Text(data)),
            OpCode::Text => Err(Error::InvalidUtf8),
            OpCode::Binary => Ok(Message::Binary(data)),
            _ => Err(Error::Protocol("invalid fragment opcode")),
        }
    }

    /// Handle close frame
    ///
    /// An empty payload yields `Message::Close(None)`. Otherwise the first two
    /// bytes are a big-endian status code (validated) followed by an optional
    /// UTF-8 reason. A 1-byte payload is a protocol error.
    fn handle_close(&mut self, frame: Frame) -> Result<Option<Message>> {
        let reason = if frame.payload.len() >= 2 {
            let code = u16::from_be_bytes([frame.payload[0], frame.payload[1]]);

            // Codes in 3000..=4999 are registered/private-use and always accepted.
            if !CloseReason::is_valid_code(code) && !(3000..=4999).contains(&code) {
                return Err(Error::InvalidCloseCode(code));
            }

            let reason_text = if frame.payload.len() > 2 {
                let text = &frame.payload[2..];
                if !validate_utf8(text) {
                    return Err(Error::InvalidUtf8);
                }
                String::from_utf8_lossy(text).into_owned()
            } else {
                String::new()
            };

            Some(CloseReason::new(code, reason_text))
        } else if frame.payload.is_empty() {
            None
        } else {
            return Err(Error::Protocol("invalid close frame payload"));
        };

        Ok(Some(Message::Close(reason)))
    }
}
1098
/// Writer half of a split compressed protocol
///
/// Contains the encoder for writing compressed messages. Text and binary
/// payloads go through the encoder; ping, pong and close frames are always
/// written uncompressed.
#[cfg(feature = "permessage-deflate")]
pub struct CompressedWriterProtocol {
    /// Endpoint role; when `Role::Client`, every outgoing frame is masked
    role: Role,
    /// Deflate encoder for outgoing text/binary payloads
    encoder: crate::deflate::DeflateEncoder,
}
1109
#[cfg(feature = "permessage-deflate")]
impl CompressedWriterProtocol {
    /// Build a writer for the server side of a connection.
    ///
    /// The encoder takes the `server_*` half of the negotiated deflate
    /// parameters, because the server compresses what it sends.
    pub fn server(config: &DeflateConfig) -> Self {
        let encoder = crate::deflate::DeflateEncoder::new(
            config.server_max_window_bits,
            config.server_no_context_takeover,
            config.compression_level,
            config.compression_threshold,
        );
        Self {
            role: Role::Server,
            encoder,
        }
    }

    /// Build a writer for the client side of a connection.
    ///
    /// The encoder takes the `client_*` half of the negotiated deflate
    /// parameters, because the client compresses what it sends.
    pub fn client(config: &DeflateConfig) -> Self {
        let encoder = crate::deflate::DeflateEncoder::new(
            config.client_max_window_bits,
            config.client_no_context_takeover,
            config.compression_level,
            config.compression_threshold,
        );
        Self {
            role: Role::Client,
            encoder,
        }
    }

    /// Serialize `msg` as a single FIN frame appended to `buf`.
    ///
    /// Text and binary payloads are offered to the encoder; when it produces
    /// output, the frame is written with RSV1 set, otherwise the original
    /// bytes are sent as-is. Control frames are never compressed.
    ///
    /// # Errors
    /// Propagates any error returned by the encoder.
    pub fn encode_message(&mut self, msg: &Message, buf: &mut BytesMut) -> Result<()> {
        // Clients mask every frame they send; servers never do.
        let mask = (self.role == Role::Client).then(crate::mask::generate_mask_fast);

        match msg {
            Message::Text(data) | Message::Binary(data) => {
                let opcode = match msg {
                    Message::Text(_) => OpCode::Text,
                    _ => OpCode::Binary,
                };
                // `None` from the encoder means "send uncompressed".
                match self.encoder.compress(data)? {
                    Some(deflated) => {
                        encode_frame_with_rsv(buf, opcode, &deflated, true, mask, true);
                    }
                    None => {
                        encode_frame(buf, opcode, data, true, mask);
                    }
                }
            }
            Message::Ping(data) => {
                encode_frame(buf, OpCode::Ping, data, true, mask);
            }
            Message::Pong(data) => {
                encode_frame(buf, OpCode::Pong, data, true, mask);
            }
            Message::Close(reason) => {
                // Close body: optional big-endian status code, then UTF-8 reason.
                let mut body = Vec::new();
                if let Some(r) = reason {
                    body.reserve(2 + r.reason.len());
                    body.extend_from_slice(&r.code.to_be_bytes());
                    body.extend_from_slice(r.reason.as_bytes());
                }
                encode_frame(buf, OpCode::Close, &body, true, mask);
            }
        }

        Ok(())
    }

    /// Append an uncompressed pong frame echoing `ping_data` to `buf`.
    pub fn encode_pong(&mut self, ping_data: &[u8], buf: &mut BytesMut) {
        let mask = (self.role == Role::Client).then(crate::mask::generate_mask_fast);
        encode_frame(buf, OpCode::Pong, ping_data, true, mask);
    }

    /// Append an empty-bodied close frame (the reply to a peer's close) to `buf`.
    pub fn encode_close_response(&mut self, buf: &mut BytesMut) {
        let mask = (self.role == Role::Client).then(crate::mask::generate_mask_fast);
        encode_frame(buf, OpCode::Close, &[], true, mask);
    }
}
1203
#[cfg(test)]
mod tests {
    use super::*;

    /// Build one masked client-to-server frame.
    ///
    /// `first_byte` carries FIN and the opcode. The payload must fit the 7-bit
    /// length field (< 126 bytes).
    fn masked_frame(first_byte: u8, mask: [u8; 4], payload: &[u8]) -> BytesMut {
        assert!(payload.len() < 126);
        let mut masked = payload.to_vec();
        crate::simd::apply_mask(&mut masked[..], mask);

        let mut frame = BytesMut::with_capacity(6 + masked.len());
        // Second byte: MASK bit plus the 7-bit payload length.
        frame.extend_from_slice(&[first_byte, 0x80 | masked.len() as u8]);
        frame.extend_from_slice(&mask);
        frame.extend_from_slice(&masked);
        frame
    }

    #[test]
    fn test_message_text() {
        let mut protocol = Protocol::new(Role::Server, 1024 * 1024, 64 * 1024 * 1024);

        // FIN + Text carrying a masked "Hello".
        let mut buf = masked_frame(0x81, [0x37, 0xfa, 0x21, 0x3d], b"Hello");

        let messages = protocol.process(&mut buf).unwrap();
        assert_eq!(messages.len(), 1);
        match &messages[0] {
            Message::Text(s) => assert_eq!(s, "Hello"),
            _ => panic!("Expected text message"),
        }
    }

    #[test]
    fn test_fragmented_message() {
        let mut protocol = Protocol::new(Role::Server, 1024 * 1024, 64 * 1024 * 1024);

        // Opening Text fragment without FIN yields nothing yet.
        let mut buf = masked_frame(0x01, [0; 4], b"Hel");
        assert!(protocol.process(&mut buf).unwrap().is_empty());

        // Closing Continuation fragment with FIN completes the message.
        buf.extend_from_slice(&masked_frame(0x80, [0; 4], b"lo"));

        let messages = protocol.process(&mut buf).unwrap();
        assert_eq!(messages.len(), 1);
        match &messages[0] {
            Message::Text(s) => assert_eq!(s, "Hello"),
            _ => panic!("Expected text message"),
        }
    }

    #[test]
    fn test_encode_message() {
        let mut protocol = Protocol::new(Role::Server, 1024 * 1024, 64 * 1024 * 1024);
        let mut buf = BytesMut::new();

        protocol
            .encode_message(&Message::text("test"), &mut buf)
            .unwrap();

        // Server frames are unmasked: FIN + Text, length 4, raw payload.
        assert_eq!(&buf[..], b"\x81\x04test");
    }
}