sentinel_proxy/websocket/
codec.rs

//! WebSocket frame codec implementing RFC 6455.
//!
//! Provides frame parsing and encoding for WebSocket connections.
4
5use bytes::{Buf, BufMut, BytesMut};
6use std::io;
7use tokio_util::codec::{Decoder, Encoder};
8use tracing::trace;
9
/// Default upper bound on a single frame's payload length, in bytes (1 MiB).
pub const DEFAULT_MAX_FRAME_SIZE: usize = 1 << 20;
12
/// WebSocket frame opcode (RFC 6455 §5.2).
///
/// The opcode occupies the low four bits of the first header byte. Opcodes
/// with the high bit of that nibble set (0x8-0xF) are control opcodes; the
/// rest (0x0-0x7) are data opcodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Opcode {
    /// Continuation frame (0x0)
    Continuation,
    /// Text frame (0x1)
    Text,
    /// Binary frame (0x2)
    Binary,
    /// Connection close (0x8)
    Close,
    /// Ping (0x9)
    Ping,
    /// Pong (0xA)
    Pong,
    /// Reserved/unknown opcode, carrying the raw value
    /// (0x3-0x7 are reserved data opcodes, 0xB-0xF reserved control opcodes)
    Reserved(u8),
}

impl Opcode {
    /// Bit of the opcode nibble that marks a control frame.
    const CONTROL_BIT: u8 = 0x08;

    /// Parse opcode from byte value.
    ///
    /// Only the low four bits of `value` are considered, so the full first
    /// header byte may be passed in directly.
    pub fn from_u8(value: u8) -> Self {
        match value & 0x0F {
            0x0 => Self::Continuation,
            0x1 => Self::Text,
            0x2 => Self::Binary,
            0x8 => Self::Close,
            0x9 => Self::Ping,
            0xA => Self::Pong,
            other => Self::Reserved(other),
        }
    }

    /// Convert to byte value.
    ///
    /// For `Reserved` the stored value is returned unchanged.
    pub fn as_u8(&self) -> u8 {
        match self {
            Self::Continuation => 0x0,
            Self::Text => 0x1,
            Self::Binary => 0x2,
            Self::Close => 0x8,
            Self::Ping => 0x9,
            Self::Pong => 0xA,
            Self::Reserved(v) => *v,
        }
    }

    /// Convert to string representation
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Continuation => "continuation",
            Self::Text => "text",
            Self::Binary => "binary",
            Self::Close => "close",
            Self::Ping => "ping",
            Self::Pong => "pong",
            Self::Reserved(_) => "reserved",
        }
    }

    /// Check if this is a control frame.
    ///
    /// Decided by the opcode's high bit rather than by listing the known
    /// variants, so reserved control opcodes (0xB-0xF) are classified as
    /// control frames too.
    pub fn is_control(&self) -> bool {
        self.as_u8() & Self::CONTROL_BIT != 0
    }

    /// Check if this is a data frame.
    ///
    /// Every opcode is either control or data, so reserved non-control
    /// opcodes (0x3-0x7) count as data frames.
    pub fn is_data(&self) -> bool {
        !self.is_control()
    }
}
82
83/// A parsed WebSocket frame
84#[derive(Debug, Clone)]
85pub struct WebSocketFrame {
86    /// FIN bit - true if this is the final frame of a message
87    pub fin: bool,
88    /// Frame opcode
89    pub opcode: Opcode,
90    /// Masking key (only for client-to-server frames)
91    pub mask: Option<[u8; 4]>,
92    /// Frame payload data (unmasked)
93    pub payload: Vec<u8>,
94}
95
96impl WebSocketFrame {
97    /// Create a new frame
98    pub fn new(opcode: Opcode, payload: Vec<u8>) -> Self {
99        Self {
100            fin: true,
101            opcode,
102            mask: None,
103            payload,
104        }
105    }
106
107    /// Create a close frame
108    pub fn close(code: u16, reason: &str) -> Self {
109        let mut payload = Vec::with_capacity(2 + reason.len());
110        payload.extend_from_slice(&code.to_be_bytes());
111        payload.extend_from_slice(reason.as_bytes());
112        Self {
113            fin: true,
114            opcode: Opcode::Close,
115            mask: None,
116            payload,
117        }
118    }
119
120    /// Create a ping frame
121    pub fn ping(data: Vec<u8>) -> Self {
122        Self {
123            fin: true,
124            opcode: Opcode::Ping,
125            mask: None,
126            payload: data,
127        }
128    }
129
130    /// Create a pong frame
131    pub fn pong(data: Vec<u8>) -> Self {
132        Self {
133            fin: true,
134            opcode: Opcode::Pong,
135            mask: None,
136            payload: data,
137        }
138    }
139
140    /// Set the masking key (for client frames)
141    pub fn with_mask(mut self, mask: [u8; 4]) -> Self {
142        self.mask = Some(mask);
143        self
144    }
145
146    /// Set the FIN bit
147    pub fn with_fin(mut self, fin: bool) -> Self {
148        self.fin = fin;
149        self
150    }
151
152    /// Parse close code and reason from payload
153    pub fn close_code_and_reason(&self) -> Option<(u16, String)> {
154        if self.opcode != Opcode::Close || self.payload.len() < 2 {
155            return None;
156        }
157        let code = u16::from_be_bytes([self.payload[0], self.payload[1]]);
158        let reason = if self.payload.len() > 2 {
159            String::from_utf8_lossy(&self.payload[2..]).to_string()
160        } else {
161            String::new()
162        };
163        Some((code, reason))
164    }
165}
166
/// WebSocket frame codec for tokio streams
///
/// Handles parsing and encoding of WebSocket frames per RFC 6455.
#[derive(Debug, Clone)]
pub struct WebSocketCodec {
    /// Maximum allowed frame size; compared against the payload length of
    /// each frame when decoding and encoding
    max_frame_size: usize,
    /// Whether we expect masked frames (true for server receiving from client).
    /// Checked by the `Decoder` impl only; `decode_frame` accepts either.
    expect_masked: bool,
    /// Whether we should mask outgoing frames (true for client sending to server).
    /// Used by the `Encoder` impl only; `encode_frame` takes an explicit flag.
    mask_outgoing: bool,
}
178
179impl WebSocketCodec {
180    /// Create a new codec with specified max frame size.
181    ///
182    /// Uses permissive settings for proxy use:
183    /// - Does not enforce masking (handles both masked and unmasked)
184    /// - Does not mask outgoing frames
185    pub fn new(max_frame_size: usize) -> Self {
186        Self {
187            max_frame_size,
188            expect_masked: false, // Permissive for proxy
189            mask_outgoing: false,
190        }
191    }
192
193    /// Create a new codec for server-side use (receiving client frames)
194    ///
195    /// - Expects masked frames from client
196    /// - Does not mask frames to client
197    pub fn server() -> Self {
198        Self {
199            max_frame_size: DEFAULT_MAX_FRAME_SIZE,
200            expect_masked: true,
201            mask_outgoing: false,
202        }
203    }
204
205    /// Create a new codec for client-side use (sending to server)
206    ///
207    /// - Does not expect masked frames from server
208    /// - Masks frames to server
209    pub fn client() -> Self {
210        Self {
211            max_frame_size: DEFAULT_MAX_FRAME_SIZE,
212            expect_masked: false,
213            mask_outgoing: true,
214        }
215    }
216
217    /// Set maximum frame size
218    pub fn with_max_frame_size(mut self, size: usize) -> Self {
219        self.max_frame_size = size;
220        self
221    }
222
223    /// Apply XOR mask to data in-place
224    fn apply_mask(data: &mut [u8], mask: [u8; 4]) {
225        for (i, byte) in data.iter_mut().enumerate() {
226            *byte ^= mask[i % 4];
227        }
228    }
229
230    /// Decode a frame from a byte slice, returning the frame and bytes consumed.
231    ///
232    /// This is a non-mutating version for use with the proxy handler.
233    /// Returns `Ok(None)` if more data is needed.
234    pub fn decode_frame(
235        &self,
236        src: &BytesMut,
237    ) -> Result<Option<(WebSocketFrame, usize)>, std::io::Error> {
238        // Need at least 2 bytes for the header
239        if src.len() < 2 {
240            return Ok(None);
241        }
242
243        // Parse first two bytes
244        let first_byte = src[0];
245        let second_byte = src[1];
246
247        let fin = (first_byte & 0x80) != 0;
248        let rsv = (first_byte & 0x70) >> 4;
249        let opcode = Opcode::from_u8(first_byte & 0x0F);
250        let masked = (second_byte & 0x80) != 0;
251        let payload_len_byte = second_byte & 0x7F;
252
253        // Check RSV bits (must be 0 unless extension negotiated)
254        if rsv != 0 {
255            return Err(std::io::Error::new(
256                std::io::ErrorKind::InvalidData,
257                "Non-zero RSV bits without extension",
258            ));
259        }
260
261        // Calculate header size and payload length
262        let (header_size, payload_len) = match payload_len_byte {
263            0..=125 => (2, payload_len_byte as usize),
264            126 => {
265                if src.len() < 4 {
266                    return Ok(None);
267                }
268                let len = u16::from_be_bytes([src[2], src[3]]) as usize;
269                (4, len)
270            }
271            127 => {
272                if src.len() < 10 {
273                    return Ok(None);
274                }
275                let len = u64::from_be_bytes([
276                    src[2], src[3], src[4], src[5], src[6], src[7], src[8], src[9],
277                ]) as usize;
278                (10, len)
279            }
280            _ => unreachable!(),
281        };
282
283        // Check frame size limit
284        if payload_len > self.max_frame_size {
285            return Err(std::io::Error::new(
286                std::io::ErrorKind::InvalidData,
287                format!(
288                    "Frame size {} exceeds maximum {}",
289                    payload_len, self.max_frame_size
290                ),
291            ));
292        }
293
294        // Calculate total frame size
295        let mask_size = if masked { 4 } else { 0 };
296        let total_size = header_size + mask_size + payload_len;
297
298        // Wait for complete frame
299        if src.len() < total_size {
300            return Ok(None);
301        }
302
303        // Extract masking key if present
304        let mask = if masked {
305            let mask_start = header_size;
306            Some([
307                src[mask_start],
308                src[mask_start + 1],
309                src[mask_start + 2],
310                src[mask_start + 3],
311            ])
312        } else {
313            None
314        };
315
316        // Extract and unmask payload
317        let payload_start = header_size + mask_size;
318        let mut payload = src[payload_start..payload_start + payload_len].to_vec();
319        if let Some(m) = mask {
320            Self::apply_mask(&mut payload, m);
321        }
322
323        Ok(Some((
324            WebSocketFrame {
325                fin,
326                opcode,
327                mask,
328                payload,
329            },
330            total_size,
331        )))
332    }
333
334    /// Encode a frame to bytes.
335    ///
336    /// If `masked` is true, the frame will be masked with a random key.
337    pub fn encode_frame(
338        &self,
339        frame: &WebSocketFrame,
340        masked: bool,
341    ) -> Result<Vec<u8>, std::io::Error> {
342        let payload_len = frame.payload.len();
343
344        // Check frame size
345        if payload_len > self.max_frame_size {
346            return Err(std::io::Error::new(
347                std::io::ErrorKind::InvalidData,
348                format!(
349                    "Frame size {} exceeds maximum {}",
350                    payload_len, self.max_frame_size
351                ),
352            ));
353        }
354
355        // Calculate sizes
356        let header_len: usize = match payload_len {
357            0..=125 => 2,
358            126..=65535 => 4,
359            _ => 10,
360        };
361        let mask_len = if masked { 4 } else { 0 };
362        let total_len = header_len + mask_len + payload_len;
363
364        let mut dst = Vec::with_capacity(total_len);
365
366        // First byte: FIN + RSV (0) + opcode
367        let first_byte = (if frame.fin { 0x80 } else { 0x00 }) | (frame.opcode.as_u8() & 0x0F);
368        dst.push(first_byte);
369
370        // Second byte: MASK + payload length
371        let mask_bit = if masked { 0x80 } else { 0x00 };
372        match payload_len {
373            0..=125 => {
374                dst.push(mask_bit | (payload_len as u8));
375            }
376            126..=65535 => {
377                dst.push(mask_bit | 126);
378                dst.extend_from_slice(&(payload_len as u16).to_be_bytes());
379            }
380            _ => {
381                dst.push(mask_bit | 127);
382                dst.extend_from_slice(&(payload_len as u64).to_be_bytes());
383            }
384        }
385
386        // Masking key and payload
387        if masked {
388            let mask: [u8; 4] = rand::random();
389            dst.extend_from_slice(&mask);
390            let mut masked_payload = frame.payload.clone();
391            Self::apply_mask(&mut masked_payload, mask);
392            dst.extend_from_slice(&masked_payload);
393        } else {
394            dst.extend_from_slice(&frame.payload);
395        }
396
397        Ok(dst)
398    }
399}
400
401impl Decoder for WebSocketCodec {
402    type Item = WebSocketFrame;
403    type Error = io::Error;
404
405    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
406        // Need at least 2 bytes for the header
407        if src.len() < 2 {
408            return Ok(None);
409        }
410
411        // Parse first two bytes
412        let first_byte = src[0];
413        let second_byte = src[1];
414
415        let fin = (first_byte & 0x80) != 0;
416        let rsv = (first_byte & 0x70) >> 4;
417        let opcode = Opcode::from_u8(first_byte & 0x0F);
418        let masked = (second_byte & 0x80) != 0;
419        let payload_len_byte = second_byte & 0x7F;
420
421        // Check RSV bits (must be 0 unless extension negotiated)
422        if rsv != 0 {
423            return Err(io::Error::new(
424                io::ErrorKind::InvalidData,
425                "Non-zero RSV bits without extension",
426            ));
427        }
428
429        // Check masking requirement
430        if self.expect_masked && !masked {
431            return Err(io::Error::new(
432                io::ErrorKind::InvalidData,
433                "Expected masked frame from client",
434            ));
435        }
436        if !self.expect_masked && masked {
437            return Err(io::Error::new(
438                io::ErrorKind::InvalidData,
439                "Unexpected masked frame from server",
440            ));
441        }
442
443        // Calculate header size and payload length
444        let (header_size, payload_len) = match payload_len_byte {
445            0..=125 => (2, payload_len_byte as usize),
446            126 => {
447                if src.len() < 4 {
448                    return Ok(None);
449                }
450                let len = u16::from_be_bytes([src[2], src[3]]) as usize;
451                (4, len)
452            }
453            127 => {
454                if src.len() < 10 {
455                    return Ok(None);
456                }
457                let len = u64::from_be_bytes([
458                    src[2], src[3], src[4], src[5], src[6], src[7], src[8], src[9],
459                ]) as usize;
460                (10, len)
461            }
462            _ => unreachable!(),
463        };
464
465        // Check frame size limit
466        if payload_len > self.max_frame_size {
467            return Err(io::Error::new(
468                io::ErrorKind::InvalidData,
469                format!(
470                    "Frame size {} exceeds maximum {}",
471                    payload_len, self.max_frame_size
472                ),
473            ));
474        }
475
476        // Calculate total frame size
477        let mask_size = if masked { 4 } else { 0 };
478        let total_size = header_size + mask_size + payload_len;
479
480        // Wait for complete frame
481        if src.len() < total_size {
482            src.reserve(total_size - src.len());
483            return Ok(None);
484        }
485
486        // Extract masking key if present
487        let mask = if masked {
488            let mask_start = header_size;
489            Some([
490                src[mask_start],
491                src[mask_start + 1],
492                src[mask_start + 2],
493                src[mask_start + 3],
494            ])
495        } else {
496            None
497        };
498
499        // Extract and unmask payload
500        let payload_start = header_size + mask_size;
501        let mut payload = src[payload_start..payload_start + payload_len].to_vec();
502        if let Some(m) = mask {
503            Self::apply_mask(&mut payload, m);
504        }
505
506        // Consume the frame from the buffer
507        src.advance(total_size);
508
509        trace!(
510            fin = fin,
511            opcode = ?opcode,
512            masked = masked,
513            payload_len = payload_len,
514            "Decoded WebSocket frame"
515        );
516
517        Ok(Some(WebSocketFrame {
518            fin,
519            opcode,
520            mask,
521            payload,
522        }))
523    }
524}
525
526impl Encoder<WebSocketFrame> for WebSocketCodec {
527    type Error = io::Error;
528
529    fn encode(&mut self, frame: WebSocketFrame, dst: &mut BytesMut) -> Result<(), Self::Error> {
530        let payload_len = frame.payload.len();
531
532        // Check frame size
533        if payload_len > self.max_frame_size {
534            return Err(io::Error::new(
535                io::ErrorKind::InvalidData,
536                format!(
537                    "Frame size {} exceeds maximum {}",
538                    payload_len, self.max_frame_size
539                ),
540            ));
541        }
542
543        // Calculate header size
544        let (header_len, extended_len_bytes): (usize, usize) = match payload_len {
545            0..=125 => (2, 0),
546            126..=65535 => (4, 2),
547            _ => (10, 8),
548        };
549
550        let should_mask = self.mask_outgoing;
551        let mask_len = if should_mask { 4 } else { 0 };
552        let total_len = header_len + mask_len + payload_len;
553
554        dst.reserve(total_len);
555
556        // First byte: FIN + RSV (0) + opcode
557        let first_byte = (if frame.fin { 0x80 } else { 0x00 }) | (frame.opcode.as_u8() & 0x0F);
558        dst.put_u8(first_byte);
559
560        // Second byte: MASK + payload length
561        let mask_bit = if should_mask { 0x80 } else { 0x00 };
562        match payload_len {
563            0..=125 => {
564                dst.put_u8(mask_bit | (payload_len as u8));
565            }
566            126..=65535 => {
567                dst.put_u8(mask_bit | 126);
568                dst.put_u16(payload_len as u16);
569            }
570            _ => {
571                dst.put_u8(mask_bit | 127);
572                dst.put_u64(payload_len as u64);
573            }
574        }
575
576        // Masking key and payload
577        if should_mask {
578            // Generate random mask
579            let mask: [u8; 4] = rand::random();
580            dst.put_slice(&mask);
581
582            // Mask and write payload
583            let mut masked_payload = frame.payload;
584            Self::apply_mask(&mut masked_payload, mask);
585            dst.put_slice(&masked_payload);
586        } else {
587            dst.put_slice(&frame.payload);
588        }
589
590        trace!(
591            fin = frame.fin,
592            opcode = ?frame.opcode,
593            masked = should_mask,
594            payload_len = payload_len,
595            "Encoded WebSocket frame"
596        );
597
598        Ok(())
599    }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Codec that expects unmasked input and writes unmasked output, with
    /// the given frame size limit.
    fn unmasked_codec(max_frame_size: usize) -> WebSocketCodec {
        WebSocketCodec {
            max_frame_size,
            expect_masked: false,
            mask_outgoing: false,
        }
    }

    /// Every non-reserved opcode survives a `from_u8` / `as_u8` round trip.
    #[test]
    fn test_opcode_round_trip() {
        for raw in 0..=15u8 {
            match Opcode::from_u8(raw) {
                Opcode::Reserved(_) => {}
                known => assert_eq!(known.as_u8(), raw),
            }
        }
    }

    /// An unmasked final text frame decodes and is fully consumed.
    #[test]
    fn test_decode_unmasked_text_frame() {
        let mut codec = unmasked_codec(DEFAULT_MAX_FRAME_SIZE);

        // FIN=1, opcode=text, no mask, len=5, payload="Hello"
        let mut buf = BytesMut::from(&b"\x81\x05Hello"[..]);

        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert!(frame.fin);
        assert_eq!(frame.opcode, Opcode::Text);
        assert_eq!(frame.payload, b"Hello");
        assert!(buf.is_empty());
    }

    /// A masked text frame is unmasked by the server-side codec.
    #[test]
    fn test_decode_masked_text_frame() {
        let mut codec = WebSocketCodec::server();

        let key = [0x37, 0xfa, 0x21, 0x3d];
        let mut body = b"Hello".to_vec();
        WebSocketCodec::apply_mask(&mut body, key);

        // FIN + text opcode, mask bit + len=5, then key and masked payload
        let wire: Vec<u8> = [0x81u8, 0x85]
            .iter()
            .chain(key.iter())
            .chain(body.iter())
            .copied()
            .collect();

        let mut buf = BytesMut::from(&wire[..]);
        let frame = codec.decode(&mut buf).unwrap().unwrap();

        assert!(frame.fin);
        assert_eq!(frame.opcode, Opcode::Text);
        assert_eq!(frame.payload, b"Hello");
    }

    /// A close frame carrying only a status code yields an empty reason.
    #[test]
    fn test_decode_close_frame() {
        let mut codec = unmasked_codec(DEFAULT_MAX_FRAME_SIZE);

        // FIN=1, opcode=close, no mask, len=2, code=1000
        let mut buf = BytesMut::from(&[0x88u8, 0x02, 0x03, 0xE8][..]);

        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert!(frame.fin);
        assert_eq!(frame.opcode, Opcode::Close);

        let (code, reason) = frame.close_code_and_reason().unwrap();
        assert_eq!(code, 1000);
        assert!(reason.is_empty());
    }

    /// An unmasked text frame encodes to the expected header and payload.
    #[test]
    fn test_encode_text_frame() {
        let mut codec = unmasked_codec(DEFAULT_MAX_FRAME_SIZE);
        let mut buf = BytesMut::new();

        codec
            .encode(WebSocketFrame::new(Opcode::Text, b"Hello".to_vec()), &mut buf)
            .unwrap();

        assert_eq!(&buf[..], &b"\x81\x05Hello"[..]);
    }

    /// A declared length above the limit is rejected from the header alone.
    #[test]
    fn test_frame_size_limit() {
        let mut codec = unmasked_codec(10);

        // Header claims a 100-byte payload
        let mut buf = BytesMut::from(&[0x81u8, 0x64][..]);

        assert!(codec.decode(&mut buf).is_err());
    }

    /// `close` builds a frame whose code and reason can be read back.
    #[test]
    fn test_close_frame_construction() {
        let frame = WebSocketFrame::close(1001, "Going away");
        assert_eq!(frame.opcode, Opcode::Close);

        let (code, reason) = frame.close_code_and_reason().unwrap();
        assert_eq!(code, 1001);
        assert_eq!(reason, "Going away");
    }
}