1use bytes::{Buf, BufMut, BytesMut};
6use std::io;
7use tokio_util::codec::{Decoder, Encoder};
8use tracing::trace;
9
/// Default upper bound, in bytes, on a single frame's payload: 1 MiB.
pub const DEFAULT_MAX_FRAME_SIZE: usize = 1 << 20;
12
/// A WebSocket frame opcode, i.e. the low four bits of a frame's first header byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Opcode {
    /// Continuation of a fragmented message (`0x0`).
    Continuation,
    /// Text data frame (`0x1`).
    Text,
    /// Binary data frame (`0x2`).
    Binary,
    /// Close control frame (`0x8`).
    Close,
    /// Ping control frame (`0x9`).
    Ping,
    /// Pong control frame (`0xA`).
    Pong,
    /// Any other value; holds the raw opcode (a 4-bit value when built by `from_u8`).
    Reserved(u8),
}

impl Opcode {
    /// Decodes an opcode from the low nibble of `value`.
    ///
    /// The upper four bits are ignored, so a raw first header byte can be passed as-is.
    pub fn from_u8(value: u8) -> Self {
        let nibble = value & 0x0F;
        match nibble {
            0x0 => Opcode::Continuation,
            0x1 => Opcode::Text,
            0x2 => Opcode::Binary,
            0x8 => Opcode::Close,
            0x9 => Opcode::Ping,
            0xA => Opcode::Pong,
            _ => Opcode::Reserved(nibble),
        }
    }

    /// Returns the numeric opcode; a `Reserved` variant yields its stored raw value.
    pub fn as_u8(&self) -> u8 {
        match *self {
            Opcode::Continuation => 0x0,
            Opcode::Text => 0x1,
            Opcode::Binary => 0x2,
            Opcode::Close => 0x8,
            Opcode::Ping => 0x9,
            Opcode::Pong => 0xA,
            Opcode::Reserved(raw) => raw,
        }
    }

    /// Returns a lowercase name for logging; every reserved value maps to `"reserved"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Opcode::Continuation => "continuation",
            Opcode::Text => "text",
            Opcode::Binary => "binary",
            Opcode::Close => "close",
            Opcode::Ping => "ping",
            Opcode::Pong => "pong",
            Opcode::Reserved(_) => "reserved",
        }
    }

    /// Returns `true` only for the control opcodes `Close`, `Ping` and `Pong`.
    pub fn is_control(&self) -> bool {
        matches!(*self, Opcode::Close | Opcode::Ping | Opcode::Pong)
    }

    /// Returns `true` only for the data opcodes `Continuation`, `Text` and `Binary`.
    ///
    /// Reserved values count as neither data nor control.
    pub fn is_data(&self) -> bool {
        matches!(*self, Opcode::Continuation | Opcode::Text | Opcode::Binary)
    }
}
82
83#[derive(Debug, Clone)]
85pub struct WebSocketFrame {
86 pub fin: bool,
88 pub opcode: Opcode,
90 pub mask: Option<[u8; 4]>,
92 pub payload: Vec<u8>,
94}
95
96impl WebSocketFrame {
97 pub fn new(opcode: Opcode, payload: Vec<u8>) -> Self {
99 Self {
100 fin: true,
101 opcode,
102 mask: None,
103 payload,
104 }
105 }
106
107 pub fn close(code: u16, reason: &str) -> Self {
109 let mut payload = Vec::with_capacity(2 + reason.len());
110 payload.extend_from_slice(&code.to_be_bytes());
111 payload.extend_from_slice(reason.as_bytes());
112 Self {
113 fin: true,
114 opcode: Opcode::Close,
115 mask: None,
116 payload,
117 }
118 }
119
120 pub fn ping(data: Vec<u8>) -> Self {
122 Self {
123 fin: true,
124 opcode: Opcode::Ping,
125 mask: None,
126 payload: data,
127 }
128 }
129
130 pub fn pong(data: Vec<u8>) -> Self {
132 Self {
133 fin: true,
134 opcode: Opcode::Pong,
135 mask: None,
136 payload: data,
137 }
138 }
139
140 pub fn with_mask(mut self, mask: [u8; 4]) -> Self {
142 self.mask = Some(mask);
143 self
144 }
145
146 pub fn with_fin(mut self, fin: bool) -> Self {
148 self.fin = fin;
149 self
150 }
151
152 pub fn close_code_and_reason(&self) -> Option<(u16, String)> {
154 if self.opcode != Opcode::Close || self.payload.len() < 2 {
155 return None;
156 }
157 let code = u16::from_be_bytes([self.payload[0], self.payload[1]]);
158 let reason = if self.payload.len() > 2 {
159 String::from_utf8_lossy(&self.payload[2..]).to_string()
160 } else {
161 String::new()
162 };
163 Some((code, reason))
164 }
165}
166
/// Frame-level WebSocket codec, used as a tokio-util `Decoder`/`Encoder`.
///
/// Public types should be `Debug` (and this one is cheaply `Clone`), so both are derived.
#[derive(Debug, Clone)]
pub struct WebSocketCodec {
    /// Largest payload length, in bytes, accepted when decoding or allowed when encoding.
    max_frame_size: usize,
    /// When `true`, the `Decoder` impl requires incoming frames to be masked.
    /// When `false`, it rejects masked frames.
    expect_masked: bool,
    /// When `true`, the `Encoder` impl masks every outgoing frame with a random key.
    mask_outgoing: bool,
}
178
179impl WebSocketCodec {
180 pub fn new(max_frame_size: usize) -> Self {
186 Self {
187 max_frame_size,
188 expect_masked: false, mask_outgoing: false,
190 }
191 }
192
193 pub fn server() -> Self {
198 Self {
199 max_frame_size: DEFAULT_MAX_FRAME_SIZE,
200 expect_masked: true,
201 mask_outgoing: false,
202 }
203 }
204
205 pub fn client() -> Self {
210 Self {
211 max_frame_size: DEFAULT_MAX_FRAME_SIZE,
212 expect_masked: false,
213 mask_outgoing: true,
214 }
215 }
216
217 pub fn with_max_frame_size(mut self, size: usize) -> Self {
219 self.max_frame_size = size;
220 self
221 }
222
223 fn apply_mask(data: &mut [u8], mask: [u8; 4]) {
225 for (i, byte) in data.iter_mut().enumerate() {
226 *byte ^= mask[i % 4];
227 }
228 }
229
230 pub fn decode_frame(
235 &self,
236 src: &BytesMut,
237 ) -> Result<Option<(WebSocketFrame, usize)>, std::io::Error> {
238 if src.len() < 2 {
240 return Ok(None);
241 }
242
243 let first_byte = src[0];
245 let second_byte = src[1];
246
247 let fin = (first_byte & 0x80) != 0;
248 let rsv = (first_byte & 0x70) >> 4;
249 let opcode = Opcode::from_u8(first_byte & 0x0F);
250 let masked = (second_byte & 0x80) != 0;
251 let payload_len_byte = second_byte & 0x7F;
252
253 if rsv != 0 {
255 return Err(std::io::Error::new(
256 std::io::ErrorKind::InvalidData,
257 "Non-zero RSV bits without extension",
258 ));
259 }
260
261 let (header_size, payload_len) = match payload_len_byte {
263 0..=125 => (2, payload_len_byte as usize),
264 126 => {
265 if src.len() < 4 {
266 return Ok(None);
267 }
268 let len = u16::from_be_bytes([src[2], src[3]]) as usize;
269 (4, len)
270 }
271 127 => {
272 if src.len() < 10 {
273 return Ok(None);
274 }
275 let len = u64::from_be_bytes([
276 src[2], src[3], src[4], src[5], src[6], src[7], src[8], src[9],
277 ]) as usize;
278 (10, len)
279 }
280 _ => unreachable!(),
281 };
282
283 if payload_len > self.max_frame_size {
285 return Err(std::io::Error::new(
286 std::io::ErrorKind::InvalidData,
287 format!(
288 "Frame size {} exceeds maximum {}",
289 payload_len, self.max_frame_size
290 ),
291 ));
292 }
293
294 let mask_size = if masked { 4 } else { 0 };
296 let total_size = header_size + mask_size + payload_len;
297
298 if src.len() < total_size {
300 return Ok(None);
301 }
302
303 let mask = if masked {
305 let mask_start = header_size;
306 Some([
307 src[mask_start],
308 src[mask_start + 1],
309 src[mask_start + 2],
310 src[mask_start + 3],
311 ])
312 } else {
313 None
314 };
315
316 let payload_start = header_size + mask_size;
318 let mut payload = src[payload_start..payload_start + payload_len].to_vec();
319 if let Some(m) = mask {
320 Self::apply_mask(&mut payload, m);
321 }
322
323 Ok(Some((
324 WebSocketFrame {
325 fin,
326 opcode,
327 mask,
328 payload,
329 },
330 total_size,
331 )))
332 }
333
334 pub fn encode_frame(
338 &self,
339 frame: &WebSocketFrame,
340 masked: bool,
341 ) -> Result<Vec<u8>, std::io::Error> {
342 let payload_len = frame.payload.len();
343
344 if payload_len > self.max_frame_size {
346 return Err(std::io::Error::new(
347 std::io::ErrorKind::InvalidData,
348 format!(
349 "Frame size {} exceeds maximum {}",
350 payload_len, self.max_frame_size
351 ),
352 ));
353 }
354
355 let header_len: usize = match payload_len {
357 0..=125 => 2,
358 126..=65535 => 4,
359 _ => 10,
360 };
361 let mask_len = if masked { 4 } else { 0 };
362 let total_len = header_len + mask_len + payload_len;
363
364 let mut dst = Vec::with_capacity(total_len);
365
366 let first_byte = (if frame.fin { 0x80 } else { 0x00 }) | (frame.opcode.as_u8() & 0x0F);
368 dst.push(first_byte);
369
370 let mask_bit = if masked { 0x80 } else { 0x00 };
372 match payload_len {
373 0..=125 => {
374 dst.push(mask_bit | (payload_len as u8));
375 }
376 126..=65535 => {
377 dst.push(mask_bit | 126);
378 dst.extend_from_slice(&(payload_len as u16).to_be_bytes());
379 }
380 _ => {
381 dst.push(mask_bit | 127);
382 dst.extend_from_slice(&(payload_len as u64).to_be_bytes());
383 }
384 }
385
386 if masked {
388 let mask: [u8; 4] = rand::random();
389 dst.extend_from_slice(&mask);
390 let mut masked_payload = frame.payload.clone();
391 Self::apply_mask(&mut masked_payload, mask);
392 dst.extend_from_slice(&masked_payload);
393 } else {
394 dst.extend_from_slice(&frame.payload);
395 }
396
397 Ok(dst)
398 }
399}
400
401impl Decoder for WebSocketCodec {
402 type Item = WebSocketFrame;
403 type Error = io::Error;
404
405 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
406 if src.len() < 2 {
408 return Ok(None);
409 }
410
411 let first_byte = src[0];
413 let second_byte = src[1];
414
415 let fin = (first_byte & 0x80) != 0;
416 let rsv = (first_byte & 0x70) >> 4;
417 let opcode = Opcode::from_u8(first_byte & 0x0F);
418 let masked = (second_byte & 0x80) != 0;
419 let payload_len_byte = second_byte & 0x7F;
420
421 if rsv != 0 {
423 return Err(io::Error::new(
424 io::ErrorKind::InvalidData,
425 "Non-zero RSV bits without extension",
426 ));
427 }
428
429 if self.expect_masked && !masked {
431 return Err(io::Error::new(
432 io::ErrorKind::InvalidData,
433 "Expected masked frame from client",
434 ));
435 }
436 if !self.expect_masked && masked {
437 return Err(io::Error::new(
438 io::ErrorKind::InvalidData,
439 "Unexpected masked frame from server",
440 ));
441 }
442
443 let (header_size, payload_len) = match payload_len_byte {
445 0..=125 => (2, payload_len_byte as usize),
446 126 => {
447 if src.len() < 4 {
448 return Ok(None);
449 }
450 let len = u16::from_be_bytes([src[2], src[3]]) as usize;
451 (4, len)
452 }
453 127 => {
454 if src.len() < 10 {
455 return Ok(None);
456 }
457 let len = u64::from_be_bytes([
458 src[2], src[3], src[4], src[5], src[6], src[7], src[8], src[9],
459 ]) as usize;
460 (10, len)
461 }
462 _ => unreachable!(),
463 };
464
465 if payload_len > self.max_frame_size {
467 return Err(io::Error::new(
468 io::ErrorKind::InvalidData,
469 format!(
470 "Frame size {} exceeds maximum {}",
471 payload_len, self.max_frame_size
472 ),
473 ));
474 }
475
476 let mask_size = if masked { 4 } else { 0 };
478 let total_size = header_size + mask_size + payload_len;
479
480 if src.len() < total_size {
482 src.reserve(total_size - src.len());
483 return Ok(None);
484 }
485
486 let mask = if masked {
488 let mask_start = header_size;
489 Some([
490 src[mask_start],
491 src[mask_start + 1],
492 src[mask_start + 2],
493 src[mask_start + 3],
494 ])
495 } else {
496 None
497 };
498
499 let payload_start = header_size + mask_size;
501 let mut payload = src[payload_start..payload_start + payload_len].to_vec();
502 if let Some(m) = mask {
503 Self::apply_mask(&mut payload, m);
504 }
505
506 src.advance(total_size);
508
509 trace!(
510 fin = fin,
511 opcode = ?opcode,
512 masked = masked,
513 payload_len = payload_len,
514 "Decoded WebSocket frame"
515 );
516
517 Ok(Some(WebSocketFrame {
518 fin,
519 opcode,
520 mask,
521 payload,
522 }))
523 }
524}
525
526impl Encoder<WebSocketFrame> for WebSocketCodec {
527 type Error = io::Error;
528
529 fn encode(&mut self, frame: WebSocketFrame, dst: &mut BytesMut) -> Result<(), Self::Error> {
530 let payload_len = frame.payload.len();
531
532 if payload_len > self.max_frame_size {
534 return Err(io::Error::new(
535 io::ErrorKind::InvalidData,
536 format!(
537 "Frame size {} exceeds maximum {}",
538 payload_len, self.max_frame_size
539 ),
540 ));
541 }
542
543 let (header_len, extended_len_bytes): (usize, usize) = match payload_len {
545 0..=125 => (2, 0),
546 126..=65535 => (4, 2),
547 _ => (10, 8),
548 };
549
550 let should_mask = self.mask_outgoing;
551 let mask_len = if should_mask { 4 } else { 0 };
552 let total_len = header_len + mask_len + payload_len;
553
554 dst.reserve(total_len);
555
556 let first_byte = (if frame.fin { 0x80 } else { 0x00 }) | (frame.opcode.as_u8() & 0x0F);
558 dst.put_u8(first_byte);
559
560 let mask_bit = if should_mask { 0x80 } else { 0x00 };
562 match payload_len {
563 0..=125 => {
564 dst.put_u8(mask_bit | (payload_len as u8));
565 }
566 126..=65535 => {
567 dst.put_u8(mask_bit | 126);
568 dst.put_u16(payload_len as u16);
569 }
570 _ => {
571 dst.put_u8(mask_bit | 127);
572 dst.put_u64(payload_len as u64);
573 }
574 }
575
576 if should_mask {
578 let mask: [u8; 4] = rand::random();
580 dst.put_slice(&mask);
581
582 let mut masked_payload = frame.payload;
584 Self::apply_mask(&mut masked_payload, mask);
585 dst.put_slice(&masked_payload);
586 } else {
587 dst.put_slice(&frame.payload);
588 }
589
590 trace!(
591 fin = frame.fin,
592 opcode = ?frame.opcode,
593 masked = should_mask,
594 payload_len = payload_len,
595 "Encoded WebSocket frame"
596 );
597
598 Ok(())
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_opcode_round_trip() {
        for i in 0..=15 {
            let opcode = Opcode::from_u8(i);
            if !matches!(opcode, Opcode::Reserved(_)) {
                assert_eq!(opcode.as_u8(), i);
            }
        }
    }

    #[test]
    fn test_decode_unmasked_text_frame() {
        let mut codec = WebSocketCodec::new(DEFAULT_MAX_FRAME_SIZE);

        let data = [0x81, 0x05, b'H', b'e', b'l', b'l', b'o'];
        let mut buf = BytesMut::from(&data[..]);

        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert!(frame.fin);
        assert_eq!(frame.opcode, Opcode::Text);
        assert_eq!(frame.payload, b"Hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_decode_masked_text_frame() {
        let mut codec = WebSocketCodec::server();

        let mask = [0x37, 0xfa, 0x21, 0x3d];
        let payload = b"Hello";
        let mut masked_payload = payload.to_vec();
        WebSocketCodec::apply_mask(&mut masked_payload, mask);

        let mut data = vec![0x81, 0x85];
        data.extend_from_slice(&mask);
        data.extend_from_slice(&masked_payload);

        let mut buf = BytesMut::from(&data[..]);
        let frame = codec.decode(&mut buf).unwrap().unwrap();

        assert!(frame.fin);
        assert_eq!(frame.opcode, Opcode::Text);
        assert_eq!(frame.payload, b"Hello");
        assert_eq!(frame.mask, Some(mask));
        assert!(buf.is_empty());
    }

    #[test]
    fn test_decode_close_frame() {
        let mut codec = WebSocketCodec::new(DEFAULT_MAX_FRAME_SIZE);

        let data = [0x88, 0x02, 0x03, 0xE8];
        let mut buf = BytesMut::from(&data[..]);

        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert!(frame.fin);
        assert_eq!(frame.opcode, Opcode::Close);
        let (code, reason) = frame.close_code_and_reason().unwrap();
        assert_eq!(code, 1000);
        assert!(reason.is_empty());
    }

    #[test]
    fn test_encode_text_frame() {
        let mut codec = WebSocketCodec::new(DEFAULT_MAX_FRAME_SIZE);

        let frame = WebSocketFrame::new(Opcode::Text, b"Hello".to_vec());
        let mut buf = BytesMut::new();
        codec.encode(frame, &mut buf).unwrap();

        assert_eq!(&buf[..], &[0x81, 0x05, b'H', b'e', b'l', b'l', b'o']);
    }

    #[test]
    fn test_frame_size_limit() {
        let mut codec = WebSocketCodec::new(10);

        // Announces a 100-byte payload against a 10-byte limit.
        let data = [0x81, 0x64];
        let mut buf = BytesMut::from(&data[..]);

        let result = codec.decode(&mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_close_frame_construction() {
        let frame = WebSocketFrame::close(1001, "Going away");
        assert_eq!(frame.opcode, Opcode::Close);

        let (code, reason) = frame.close_code_and_reason().unwrap();
        assert_eq!(code, 1001);
        assert_eq!(reason, "Going away");
    }

    #[test]
    fn test_partial_frame_waits_for_more_data() {
        let mut codec = WebSocketCodec::new(DEFAULT_MAX_FRAME_SIZE);
        let mut buf = BytesMut::from(&[0x81u8, 0x05, b'H', b'e'][..]);

        assert!(codec.decode(&mut buf).unwrap().is_none());
        // Nothing is consumed until the whole frame has arrived.
        assert_eq!(buf.len(), 4);

        buf.extend_from_slice(b"llo");
        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame.payload, b"Hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_partial_extended_length_waits() {
        let mut codec = WebSocketCodec::new(DEFAULT_MAX_FRAME_SIZE);
        let mut buf = BytesMut::from(&[0x82u8, 126, 0x00][..]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 3);
    }

    #[test]
    fn test_server_rejects_unmasked_frame() {
        let mut codec = WebSocketCodec::server();
        let mut buf = BytesMut::from(&[0x81u8, 0x00][..]);
        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_client_rejects_masked_frame() {
        let mut codec = WebSocketCodec::client();
        let mut buf = BytesMut::from(&[0x81u8, 0x80, 0, 0, 0, 0][..]);
        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_rsv_bits_rejected() {
        let mut codec = WebSocketCodec::new(DEFAULT_MAX_FRAME_SIZE);
        let mut buf = BytesMut::from(&[0xC1u8, 0x00][..]);
        assert!(codec.decode(&mut buf).is_err());
    }

    #[test]
    fn test_extended_16bit_length_round_trip() {
        let mut codec = WebSocketCodec::new(DEFAULT_MAX_FRAME_SIZE);
        let payload = vec![0xABu8; 200];
        let mut buf = BytesMut::new();

        codec
            .encode(WebSocketFrame::new(Opcode::Binary, payload.clone()), &mut buf)
            .unwrap();
        assert_eq!(&buf[..4], &[0x82, 126, 0x00, 0xC8]);
        assert_eq!(buf.len(), 4 + 200);

        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame.opcode, Opcode::Binary);
        assert_eq!(frame.payload, payload);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_client_masked_frame_decodes_on_server() {
        let mut client = WebSocketCodec::client();
        let mut server = WebSocketCodec::server();
        let mut buf = BytesMut::new();

        client
            .encode(WebSocketFrame::new(Opcode::Text, b"Hello".to_vec()), &mut buf)
            .unwrap();
        assert_eq!(buf.len(), 2 + 4 + 5);
        assert_eq!(buf[1], 0x80 | 5);

        let frame = server.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame.opcode, Opcode::Text);
        assert_eq!(frame.payload, b"Hello");
        assert!(frame.mask.is_some());
        assert!(buf.is_empty());
    }
}