1use std::{fmt, mem::replace, num::NonZeroU16, ops::Deref};
3
4use bytes::{BufMut, Bytes, BytesMut};
5
6use super::error::ProtocolError;
7use crate::utf8;
8
9#[derive(Debug, PartialEq, Eq, Clone, Copy)]
14pub(super) enum OpCode {
15 Continuation,
18 Text,
20 Binary,
22 Close,
24 Ping,
26 Pong,
28}
29
30impl OpCode {
31 pub(super) fn is_control(self) -> bool {
33 matches!(self, Self::Close | Self::Ping | Self::Pong)
34 }
35}
36
37impl TryFrom<u8> for OpCode {
38 type Error = ProtocolError;
39
40 fn try_from(value: u8) -> Result<Self, Self::Error> {
41 match value {
42 0 => Ok(Self::Continuation),
43 1 => Ok(Self::Text),
44 2 => Ok(Self::Binary),
45 8 => Ok(Self::Close),
46 9 => Ok(Self::Ping),
47 10 => Ok(Self::Pong),
48 _ => Err(ProtocolError::InvalidOpcode),
49 }
50 }
51}
52
53impl From<OpCode> for u8 {
54 fn from(value: OpCode) -> Self {
55 match value {
56 OpCode::Continuation => 0,
57 OpCode::Text => 1,
58 OpCode::Binary => 2,
59 OpCode::Close => 8,
60 OpCode::Ping => 9,
61 OpCode::Pong => 10,
62 }
63 }
64}
65
/// Status code carried at the start of a close frame's payload.
///
/// Only the values accepted by `try_from_u16` (`1000..=1015` and
/// `3000..=4999`) can be represented; the raw value is stored as a
/// `NonZeroU16`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CloseCode(NonZeroU16);

impl CloseCode {
    /// 1000: the connection fulfilled its purpose and is closing normally.
    pub const NORMAL_CLOSURE: Self = Self::constant(1000);
    /// 1001: the endpoint is going away.
    pub const GOING_AWAY: Self = Self::constant(1001);
    /// 1002: the connection is being closed because of a protocol error.
    pub const PROTOCOL_ERROR: Self = Self::constant(1002);
    /// 1003: the endpoint received a type of data it cannot accept.
    pub const UNSUPPORTED_DATA: Self = Self::constant(1003);
    /// 1005: no status code was present. Reserved; must not be sent.
    pub const NO_STATUS_RECEIVED: Self = Self::constant(1005);
    /// 1007: message data was inconsistent with its type (e.g. invalid UTF-8).
    pub const INVALID_FRAME_PAYLOAD_DATA: Self = Self::constant(1007);
    /// 1008: a message violated the endpoint's policy.
    pub const POLICY_VIOLATION: Self = Self::constant(1008);
    /// 1009: a message was too big to process.
    pub const MESSAGE_TOO_BIG: Self = Self::constant(1009);
    /// 1010: a required extension was not negotiated.
    pub const MANDATORY_EXTENSION: Self = Self::constant(1010);
    /// 1011: the server hit an unexpected condition.
    pub const INTERNAL_SERVER_ERROR: Self = Self::constant(1011);
    /// 1012: the service is restarting.
    pub const SERVICE_RESTART: Self = Self::constant(1012);
    /// 1013: the service is overloaded.
    pub const SERVICE_OVERLOAD: Self = Self::constant(1013);
    /// 1014: a gateway received an invalid response from upstream.
    pub const BAD_GATEWAY: Self = Self::constant(1014);

    /// Validates a raw status code; `None` unless it lies in `1000..=1015`
    /// or `3000..=4999`.
    const fn try_from_u16(code: u16) -> Option<Self> {
        if !matches!(code, 1000..=1015 | 3000..=4999) {
            return None;
        }
        // Every accepted value is >= 1000, so it can never be zero.
        match NonZeroU16::new(code) {
            Some(non_zero) => Some(Self(non_zero)),
            None => unreachable!(),
        }
    }

    /// Builds an associated constant. It is evaluated in const context, so an
    /// invalid literal is rejected at compile time.
    const fn constant(code: u16) -> Self {
        if let Some(close_code) = Self::try_from_u16(code) {
            close_code
        } else {
            unreachable!()
        }
    }

    /// Returns whether this code is reserved (1004, 1005, 1006 or 1015) and
    /// therefore may not be placed in an outgoing close frame.
    #[must_use]
    pub fn is_reserved(self) -> bool {
        let code = self.0.get();
        // Construction always goes through `try_from_u16`, so a value outside
        // 1000..=4999 means an invariant was broken somewhere.
        debug_assert!((1000..=4999).contains(&code), "unexpected CloseCode");
        matches!(code, 1004..=1006 | 1015)
    }
}

impl From<CloseCode> for u16 {
    /// Returns the numeric status code.
    fn from(value: CloseCode) -> Self {
        let CloseCode(raw) = value;
        raw.get()
    }
}
165
166impl TryFrom<u16> for CloseCode {
167 type Error = ProtocolError;
168
169 fn try_from(value: u16) -> Result<Self, Self::Error> {
170 Self::try_from_u16(value).ok_or(ProtocolError::InvalidCloseCode)
171 }
172}
173
174#[derive(Clone)]
184pub struct Payload {
185 data: Bytes,
187 utf8_validated: bool,
189}
190
191impl Payload {
192 const fn from_static(bytes: &'static [u8]) -> Self {
194 Self {
195 data: Bytes::from_static(bytes),
196 utf8_validated: false,
197 }
198 }
199
200 pub(super) fn set_utf8_validated(&mut self, value: bool) {
202 self.utf8_validated = value;
203 }
204
205 pub(super) fn truncate(&mut self, len: usize) {
208 self.data.truncate(len);
209 }
210
211 fn split_to(&mut self, at: usize) -> Self {
213 self.utf8_validated = false;
217 Self {
218 data: self.data.split_to(at),
219 utf8_validated: false,
220 }
221 }
222}
223
224impl Deref for Payload {
225 type Target = [u8];
226
227 fn deref(&self) -> &Self::Target {
228 &self.data
229 }
230}
231
232impl fmt::Debug for Payload {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 f.debug_tuple("Payload").field(&self.data).finish()
235 }
236}
237
238impl From<Bytes> for Payload {
239 fn from(value: Bytes) -> Self {
240 Self {
241 data: value,
242 utf8_validated: false,
243 }
244 }
245}
246
247impl From<BytesMut> for Payload {
248 fn from(value: BytesMut) -> Self {
249 Self {
250 data: value.freeze(),
251 utf8_validated: false,
252 }
253 }
254}
255
256impl From<Payload> for Bytes {
257 fn from(value: Payload) -> Self {
258 value.data
259 }
260}
261
262impl From<Payload> for BytesMut {
263 fn from(value: Payload) -> Self {
264 value.data.into()
265 }
266}
267
268impl From<Vec<u8>> for Payload {
269 fn from(value: Vec<u8>) -> Self {
270 Self {
274 data: BytesMut::from_iter(value).freeze(),
275 utf8_validated: false,
276 }
277 }
278}
279
280impl From<String> for Payload {
281 fn from(value: String) -> Self {
282 Self {
284 data: BytesMut::from_iter(value.into_bytes()).freeze(),
285 utf8_validated: true,
286 }
287 }
288}
289
290impl From<&'static [u8]> for Payload {
291 fn from(value: &'static [u8]) -> Self {
292 Self {
293 data: Bytes::from_static(value),
294 utf8_validated: false,
295 }
296 }
297}
298
299impl From<&'static str> for Payload {
300 fn from(value: &'static str) -> Self {
301 Self {
302 data: Bytes::from_static(value.as_bytes()),
303 utf8_validated: true,
304 }
305 }
306}
307
308#[derive(Debug, Clone)]
314pub struct Message {
315 pub(super) opcode: OpCode,
317 pub(super) payload: Payload,
319}
320
321impl Message {
322 #[must_use]
324 pub fn text<P: Into<Payload>>(payload: P) -> Self {
325 Self {
326 opcode: OpCode::Text,
327 payload: payload.into(),
328 }
329 }
330
331 #[must_use]
333 pub fn binary<P: Into<Payload>>(payload: P) -> Self {
334 Self {
335 opcode: OpCode::Binary,
336 payload: payload.into(),
337 }
338 }
339
340 #[must_use]
348 #[track_caller]
349 pub fn close(code: Option<CloseCode>, reason: &str) -> Self {
350 let mut payload = BytesMut::with_capacity((2 + reason.len()) * usize::from(code.is_some()));
351
352 if let Some(code) = code {
353 assert!(!code.is_reserved());
354 payload.put_u16(code.into());
355
356 assert!(reason.len() <= 123);
357 payload.extend_from_slice(reason.as_bytes());
358 }
359
360 Self {
361 opcode: OpCode::Close,
362 payload: payload.into(),
363 }
364 }
365
366 #[must_use]
371 #[track_caller]
372 pub fn ping<P: Into<Payload>>(payload: P) -> Self {
373 let payload = payload.into();
374 assert!(payload.len() <= 125);
375 Self {
376 opcode: OpCode::Ping,
377 payload,
378 }
379 }
380
381 #[must_use]
386 #[track_caller]
387 pub fn pong<P: Into<Payload>>(payload: P) -> Self {
388 let payload = payload.into();
389 assert!(payload.len() <= 125);
390 Self {
391 opcode: OpCode::Pong,
392 payload,
393 }
394 }
395
396 #[must_use]
398 pub fn is_text(&self) -> bool {
399 self.opcode == OpCode::Text
400 }
401
402 #[must_use]
404 pub fn is_binary(&self) -> bool {
405 self.opcode == OpCode::Binary
406 }
407
408 #[must_use]
410 pub fn is_close(&self) -> bool {
411 self.opcode == OpCode::Close
412 }
413
414 #[must_use]
416 pub fn is_ping(&self) -> bool {
417 self.opcode == OpCode::Ping
418 }
419
420 #[must_use]
422 pub fn is_pong(&self) -> bool {
423 self.opcode == OpCode::Pong
424 }
425
426 #[must_use]
429 pub fn into_payload(self) -> Payload {
430 self.payload
431 }
432
433 pub fn as_payload(&self) -> &Payload {
435 &self.payload
436 }
437
438 pub fn as_text(&self) -> Option<&str> {
446 (self.opcode == OpCode::Text).then(|| {
449 assert!(
450 self.payload.utf8_validated || utf8::parse_str(&self.payload).is_ok(),
451 "called as_text on message created from payload with invalid utf-8"
452 );
453 unsafe { std::str::from_utf8_unchecked(&self.payload) }
454 })
455 }
456
457 pub fn as_close(&self) -> Option<(CloseCode, &str)> {
460 (self.opcode == OpCode::Close).then(|| {
461 let code = if self.payload.is_empty() {
462 CloseCode::NO_STATUS_RECEIVED
463 } else {
464 unsafe {
467 CloseCode::try_from(u16::from_be_bytes(
468 self.payload
469 .get_unchecked(0..2)
470 .try_into()
471 .unwrap_unchecked(),
472 ))
473 .unwrap_unchecked()
474 }
475 };
476
477 let reason =
479 unsafe { std::str::from_utf8_unchecked(self.payload.get(2..).unwrap_or_default()) };
480
481 (code, reason)
482 })
483 }
484
485 pub(super) fn into_frames(self, frame_size: usize) -> MessageFrames {
488 MessageFrames {
489 frame_size,
490 payload: self.payload,
491 opcode: self.opcode,
492 }
493 }
494}
495
496pub(super) struct MessageFrames {
498 frame_size: usize,
500 payload: Payload,
502 opcode: OpCode,
504}
505
506impl Iterator for MessageFrames {
507 type Item = Frame;
508
509 fn next(&mut self) -> Option<Self::Item> {
510 let is_empty = self.payload.is_empty() && self.opcode == OpCode::Continuation;
511
512 (!is_empty).then(|| {
513 let payload = self
514 .payload
515 .split_to(self.frame_size.min(self.payload.len()));
516
517 Frame {
518 opcode: replace(&mut self.opcode, OpCode::Continuation),
519 is_final: self.payload.is_empty(),
520 payload,
521 }
522 })
523 }
524}
525
526#[derive(Debug, Clone, Copy)]
531pub struct Limits {
532 pub(super) max_payload_len: usize,
534}
535
536impl Limits {
537 #[must_use]
539 pub fn unlimited() -> Self {
540 Self {
541 max_payload_len: usize::MAX,
542 }
543 }
544
545 #[must_use]
549 pub fn max_payload_len(mut self, size: Option<usize>) -> Self {
550 self.set_max_payload_len(size);
551
552 self
553 }
554
555 pub fn set_max_payload_len(&mut self, size: Option<usize>) {
557 self.max_payload_len = size.unwrap_or(usize::MAX);
558 }
559}
560
561impl Default for Limits {
562 fn default() -> Self {
563 Self {
564 max_payload_len: 64 * 1024 * 1024,
565 }
566 }
567}
568
569#[derive(Debug, Clone, Copy)]
574pub struct Config {
575 pub(super) frame_size: usize,
580 pub(super) flush_threshold: usize,
583}
584
585impl Config {
586 #[must_use]
595 pub fn frame_size(mut self, frame_size: usize) -> Self {
596 assert_ne!(frame_size, 0, "frame_size must be non-zero");
597 self.frame_size = frame_size;
598
599 self
600 }
601
602 #[must_use]
605 pub fn flush_threshold(mut self, threshold: usize) -> Self {
606 self.flush_threshold = threshold;
607
608 self
609 }
610}
611
612impl Default for Config {
613 fn default() -> Self {
614 Self {
615 frame_size: 4 * 1024 * 1024,
616 flush_threshold: 8 * 1024,
617 }
618 }
619}
620
/// Which side of the connection this endpoint is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Role {
    /// The client side.
    Client,
    /// The server side.
    Server,
}
629
630#[derive(Debug, PartialEq)]
632pub(super) enum StreamState {
633 Active,
635 ClosedByPeer,
638 ClosedByUs,
640 CloseAcknowledged,
643}
644
645#[derive(Clone, Debug)]
647pub(super) struct Frame {
648 pub opcode: OpCode,
650 pub is_final: bool,
652 pub payload: Payload,
654}
655
656impl Frame {
657 #[allow(clippy::declare_interior_mutable_const)]
659 pub const DEFAULT_CLOSE: Self = Self {
660 opcode: OpCode::Close,
661 is_final: true,
662 payload: Payload::from_static(&CloseCode::NORMAL_CLOSURE.0.get().to_be_bytes()),
663 };
664
665 pub fn encode<'a>(&self, out: &'a mut [u8; 14]) -> &'a mut [u8; 4] {
668 out[0] = (u8::from(self.is_final) << 7) | u8::from(self.opcode);
669 let mask_slice = if u16::try_from(self.payload.len()).is_err() {
670 out[1] = 127;
671 let len = u64::try_from(self.payload.len()).unwrap();
672 out[2..10].copy_from_slice(&len.to_be_bytes());
673 &mut out[10..14]
674 } else if self.payload.len() > 125 {
675 out[1] = 126;
676 let len = u16::try_from(self.payload.len()).expect("checked by previous branch");
677 out[2..4].copy_from_slice(&len.to_be_bytes());
678 &mut out[4..8]
679 } else {
680 out[1] = u8::try_from(self.payload.len()).expect("checked by previous branch");
681 &mut out[2..6]
682 };
683 mask_slice.try_into().unwrap()
684 }
685}
686
687impl From<Message> for Frame {
688 fn from(value: Message) -> Self {
689 Self {
690 opcode: value.opcode,
691 is_final: true,
692 payload: value.payload,
693 }
694 }
695}
696
697impl From<&ProtocolError> for Frame {
698 fn from(val: &ProtocolError) -> Self {
699 match val {
700 ProtocolError::InvalidUtf8 => {
701 Message::close(Some(CloseCode::INVALID_FRAME_PAYLOAD_DATA), "invalid utf8")
702 }
703 _ => Message::close(Some(CloseCode::PROTOCOL_ERROR), val.as_str()),
704 }
705 .into()
706 }
707}