1mod frame;
4
5#[cfg(feature = "client")]
6mod client;
7
8#[cfg(feature = "server")]
9mod server;
10
11use std::{
12 borrow::Cow,
13 io,
14 mem::MaybeUninit,
15 pin::Pin,
16 task::{Context, Poll},
17};
18
19use base64::Engine;
20use bytes::{Buf, BufMut, Bytes, BytesMut};
21use futures::{Sink, SinkExt, Stream, StreamExt, ready};
22use sha1::{Digest, Sha1};
23use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
24
25use self::frame::{Frame, InvalidFrame};
26
27use crate::connection::Upgraded;
28
29#[cfg(feature = "server")]
30use crate::{Error, server::IncomingRequest};
31
32#[cfg(feature = "client")]
33#[cfg_attr(docsrs, doc(cfg(feature = "client")))]
34pub use self::client::{ClientHandshake, ClientHandshakeBuilder};
35
36#[cfg(feature = "server")]
37#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
38pub use self::server::{FutureServer, ServerHandshake};
39
40pub fn create_key() -> String {
42 base64::prelude::BASE64_STANDARD.encode(&rand::random::<[u8; 16]>()[..])
43}
44
45pub fn create_accept_token(key: &[u8]) -> String {
47 let suffix = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
48
49 let mut input = Vec::with_capacity(key.len() + suffix.len());
50
51 input.extend_from_slice(key);
52 input.extend_from_slice(suffix);
53
54 let hash = Sha1::digest(&input);
55
56 base64::prelude::BASE64_STANDARD.encode(hash.as_slice())
57}
58
/// Errors produced while reading from a WebSocket, before they are mapped to
/// `io::Error` for callers.
enum InternalError {
    /// The peer violated the framing protocol.
    ProtocolError,
    /// A text payload (or close reason) was not valid UTF-8.
    InvalidString,
    /// A frame or reassembled message exceeded the configured input capacity.
    MessageSizeExceeded,
    /// The connection ended in the middle of a frame.
    UnexpectedEof,
    /// An error from the underlying transport.
    IO(io::Error),
}
67
68impl InternalError {
69 fn to_close_message(&self) -> Option<CloseMessage> {
71 let status = match self {
72 Self::ProtocolError => CloseMessage::STATUS_PROTOCOL_ERROR,
73 Self::InvalidString => CloseMessage::STATUS_INVALID_DATA,
74 Self::MessageSizeExceeded => CloseMessage::STATUS_TOO_BIG,
75 _ => return None,
76 };
77
78 Some(CloseMessage::new_static(status, ""))
79 }
80}
81
82impl From<InvalidFrame> for InternalError {
83 fn from(_: InvalidFrame) -> Self {
84 Self::ProtocolError
85 }
86}
87
88impl From<io::Error> for InternalError {
89 fn from(err: io::Error) -> Self {
90 Self::IO(err)
91 }
92}
93
/// Which end of the connection this endpoint plays.
///
/// The role is handed to the frame encoder and decoder.
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum AgentRole {
    /// The endpoint that initiated the connection.
    Client,
    /// The endpoint that accepted the connection.
    Server,
}
100
101#[derive(Clone)]
103pub enum Message {
104 Text(String),
105 Data(Bytes),
106 Ping(Bytes),
107 Pong(Bytes),
108 Close(CloseMessage),
109}
110
111impl Message {
112 fn into_frame(self) -> Frame {
114 match self {
115 Self::Text(text) => Frame::new(Frame::OPCODE_TEXT, text.into(), true),
116 Self::Data(data) => Frame::new(Frame::OPCODE_BINARY, data, true),
117 Self::Ping(data) => Frame::new(Frame::OPCODE_PING, data, true),
118 Self::Pong(data) => Frame::new(Frame::OPCODE_PONG, data, true),
119 Self::Close(close) => close.into_frame(),
120 }
121 }
122}
123
124impl From<CloseMessage> for Message {
125 #[inline]
126 fn from(close: CloseMessage) -> Self {
127 Self::Close(close)
128 }
129}
130
/// Payload of a close control message: a status code plus a textual reason.
#[derive(Clone)]
pub struct CloseMessage {
    /// Close status code (see the `STATUS_*` constants).
    status: u16,
    /// Human-readable reason; borrowed when built from a `&'static str`.
    message: Cow<'static, str>,
}
137
138impl CloseMessage {
139 pub const STATUS_OK: u16 = 1000;
140 pub const STATUS_GOING_AWAY: u16 = 1001;
141 pub const STATUS_PROTOCOL_ERROR: u16 = 1002;
142 pub const STATUS_UNEXPECTED_DATA: u16 = 1003;
143 pub const STATUS_INVALID_DATA: u16 = 1007;
144 pub const STATUS_TOO_BIG: u16 = 1009;
145
146 pub fn new<T>(status: u16, msg: T) -> Self
149 where
150 T: ToString,
151 {
152 Self {
153 status,
154 message: Cow::Owned(msg.to_string()),
155 }
156 }
157
158 #[inline]
161 pub const fn new_static(status: u16, msg: &'static str) -> Self {
162 Self {
163 status,
164 message: Cow::Borrowed(msg),
165 }
166 }
167
168 #[inline]
170 pub fn status(&self) -> u16 {
171 self.status
172 }
173
174 #[inline]
176 pub fn message(&self) -> &str {
177 &self.message
178 }
179
180 fn into_frame(self) -> Frame {
182 let mut data = BytesMut::with_capacity(self.message.len() + 2);
183
184 data.put_u16(self.status);
185 data.extend_from_slice(self.message.as_bytes());
186
187 Frame::new(Frame::OPCODE_CLOSE, data.freeze(), true)
188 }
189}
190
191pub struct WebSocket {
193 inner: Option<FrameSocket>,
194 current_msg_type: Option<u8>,
195 current_msg_data: Vec<u8>,
196 input_buffer_capacity: usize,
197 closed: bool,
198}
199
200impl WebSocket {
201 #[cfg(feature = "client")]
203 #[cfg_attr(docsrs, doc(cfg(feature = "client")))]
204 #[inline]
205 pub fn client() -> ClientHandshakeBuilder {
206 ClientHandshake::builder()
207 }
208
209 #[cfg(feature = "server")]
211 #[cfg_attr(docsrs, doc(cfg(feature = "server")))]
212 #[inline]
213 pub fn server(request: IncomingRequest) -> Result<ServerHandshake, Error> {
214 ServerHandshake::new(request)
215 }
216
217 #[inline]
219 pub fn new(upgraded: Upgraded, agent_role: AgentRole, input_buffer_capacity: usize) -> Self {
220 let inner = FrameSocket::new(upgraded, agent_role, input_buffer_capacity);
221
222 Self {
223 inner: Some(inner),
224 current_msg_type: None,
225 current_msg_data: Vec::new(),
226 input_buffer_capacity,
227 closed: false,
228 }
229 }
230
231 fn process_frame(&mut self, frame: Frame) -> Result<Option<Message>, InternalError> {
233 let opcode = frame.opcode();
234 let fin = frame.fin();
235 let data = frame.into_payload();
236
237 match opcode {
238 Frame::OPCODE_CONTINUATION => self.process_continuation_frame(&data, fin),
239 Frame::OPCODE_BINARY => self.process_binary_frame(data, fin),
240 Frame::OPCODE_TEXT => self.process_text_frame(data, fin),
241 Frame::OPCODE_PING => self.process_ping_frame(data, fin),
242 Frame::OPCODE_PONG => self.process_pong_frame(data, fin),
243 Frame::OPCODE_CLOSE => self.process_close_frame(data, fin),
244 _ => Err(InternalError::ProtocolError),
245 }
246 }
247
248 fn process_continuation_frame(
250 &mut self,
251 data: &[u8],
252 fin: bool,
253 ) -> Result<Option<Message>, InternalError> {
254 let msg_type = self.current_msg_type.ok_or(InternalError::ProtocolError)?;
255
256 if (self.current_msg_data.len() + data.len()) > self.input_buffer_capacity {
257 return Err(InternalError::MessageSizeExceeded);
258 }
259
260 self.current_msg_data.extend(data);
261
262 if !fin {
263 return Ok(None);
264 }
265
266 self.current_msg_type = None;
267
268 let data = Bytes::from(std::mem::take(&mut self.current_msg_data));
269
270 match msg_type {
271 Frame::OPCODE_BINARY => self.process_binary_frame(data, true),
272 Frame::OPCODE_TEXT => self.process_text_frame(data, true),
273 _ => unreachable!(),
274 }
275 }
276
277 fn process_binary_frame(
279 &mut self,
280 data: Bytes,
281 fin: bool,
282 ) -> Result<Option<Message>, InternalError> {
283 if self.current_msg_type.is_some() {
284 return Err(InternalError::ProtocolError);
285 }
286
287 if fin {
288 Ok(Some(Message::Data(data)))
289 } else {
290 self.current_msg_type = Some(Frame::OPCODE_BINARY);
291 self.current_msg_data = data.to_vec();
292
293 Ok(None)
294 }
295 }
296
297 fn process_text_frame(
299 &mut self,
300 data: Bytes,
301 fin: bool,
302 ) -> Result<Option<Message>, InternalError> {
303 if self.current_msg_type.is_some() {
304 return Err(InternalError::ProtocolError);
305 }
306
307 if fin {
308 let text = std::str::from_utf8(&data)
309 .map_err(|_| InternalError::InvalidString)?
310 .to_string();
311
312 Ok(Some(Message::Text(text)))
313 } else {
314 self.current_msg_type = Some(Frame::OPCODE_TEXT);
315 self.current_msg_data = data.to_vec();
316
317 Ok(None)
318 }
319 }
320
321 fn process_ping_frame(
323 &mut self,
324 data: Bytes,
325 fin: bool,
326 ) -> Result<Option<Message>, InternalError> {
327 if !fin {
328 return Err(InternalError::ProtocolError);
329 }
330
331 Ok(Some(Message::Ping(data)))
332 }
333
334 fn process_pong_frame(
336 &mut self,
337 data: Bytes,
338 fin: bool,
339 ) -> Result<Option<Message>, InternalError> {
340 if !fin {
341 return Err(InternalError::ProtocolError);
342 }
343
344 Ok(Some(Message::Pong(data)))
345 }
346
347 fn process_close_frame(
349 &mut self,
350 mut data: Bytes,
351 fin: bool,
352 ) -> Result<Option<Message>, InternalError> {
353 if !fin {
354 return Err(InternalError::ProtocolError);
355 }
356
357 let status = if data.len() < 2 {
358 data.clear();
360
361 1005
362 } else {
363 data.get_u16()
364 };
365
366 let msg = std::str::from_utf8(&data)
367 .map_err(|_| InternalError::InvalidString)?
368 .to_string();
369
370 let msg = CloseMessage::new(status, msg);
371
372 self.closed = true;
373
374 Ok(Some(msg.into()))
375 }
376
377 fn poll_next_inner(
379 &mut self,
380 cx: &mut Context<'_>,
381 ) -> Poll<Option<Result<Message, InternalError>>> {
382 loop {
383 if self.closed {
384 return Poll::Ready(None);
385 } else if let Some(inner) = self.inner.as_mut() {
386 if let Poll::Ready(ready) = inner.poll_next_unpin(cx) {
387 if let Some(frame) = ready.transpose()? {
388 if let Some(msg) = self.process_frame(frame)? {
389 return Poll::Ready(Some(Ok(msg)));
390 }
391 } else {
392 return Poll::Ready(None);
393 }
394 } else {
395 return Poll::Pending;
396 }
397 } else {
398 return Poll::Ready(None);
399 }
400 }
401 }
402}
403
404impl Stream for WebSocket {
405 type Item = io::Result<Message>;
406
407 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
408 match ready!(self.poll_next_inner(cx)) {
409 Some(Ok(msg)) => Poll::Ready(Some(Ok(msg))),
410 Some(Err(err)) => {
411 if let Some(msg) = err.to_close_message() {
412 if let Some(mut inner) = self.inner.take() {
413 tokio::spawn(async move {
414 let _ = inner.send(msg.into_frame()).await;
415 });
416 }
417 }
418
419 let err = match err {
420 InternalError::UnexpectedEof => io::Error::from(io::ErrorKind::UnexpectedEof),
421 InternalError::IO(err) => err,
422 _ => io::Error::from(io::ErrorKind::InvalidData),
423 };
424
425 Poll::Ready(Some(Err(err)))
426 }
427 None => Poll::Ready(None),
428 }
429 }
430}
431
432impl Sink<Message> for WebSocket {
433 type Error = io::Error;
434
435 #[inline]
436 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
437 self.inner
438 .as_mut()
439 .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?
440 .poll_ready_unpin(cx)
441 }
442
443 fn start_send(mut self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> {
444 let this = &mut *self;
446
447 if this.closed {
448 return Err(io::Error::from(io::ErrorKind::BrokenPipe));
449 }
450
451 let inner = this
452 .inner
453 .as_mut()
454 .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?;
455
456 let frame = msg.into_frame();
457
458 this.closed |= frame.opcode() == Frame::OPCODE_CLOSE;
459
460 inner.start_send_unpin(frame)
461 }
462
463 #[inline]
464 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
465 self.inner
466 .as_mut()
467 .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?
468 .poll_flush_unpin(cx)
469 }
470
471 #[inline]
472 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
473 self.inner
474 .as_mut()
475 .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?
476 .poll_close_unpin(cx)
477 }
478}
479
480struct FrameSocket {
482 upgraded: Upgraded,
483 agent_role: AgentRole,
484 input_buffer: BytesMut,
485 output_buffer: BytesMut,
486 input_buffer_capacity: usize,
487 sent: usize,
488}
489
490impl FrameSocket {
491 #[inline]
493 fn new(upgraded: Upgraded, agent_role: AgentRole, input_buffer_capacity: usize) -> Self {
494 Self {
495 upgraded,
496 agent_role,
497 input_buffer: BytesMut::new(),
498 output_buffer: BytesMut::new(),
499 input_buffer_capacity,
500 sent: 0,
501 }
502 }
503}
504
505impl Stream for FrameSocket {
506 type Item = Result<Frame, InternalError>;
507
508 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
509 let mut buffer: [MaybeUninit<u8>; 8192] = unsafe { MaybeUninit::uninit().assume_init() };
510
511 let this = &mut *self;
513
514 loop {
515 if let Some(frame) = Frame::decode(&mut this.input_buffer, this.agent_role)? {
516 return Poll::Ready(Some(Ok(frame)));
517 } else if this.input_buffer.len() >= this.input_buffer_capacity {
518 return Poll::Ready(Some(Err(InternalError::MessageSizeExceeded)));
519 }
520
521 let available = this.input_buffer_capacity - this.input_buffer.len();
522 let read = available.min(buffer.len());
523
524 let mut buffer = ReadBuf::uninit(&mut buffer[..read]);
525
526 let pinned = Pin::new(&mut this.upgraded);
527
528 ready!(pinned.poll_read(cx, &mut buffer))?;
529
530 let filled = buffer.filled();
531
532 if !filled.is_empty() {
533 this.input_buffer.extend_from_slice(filled);
534 } else if this.input_buffer.is_empty() {
535 return Poll::Ready(None);
536 } else {
537 return Poll::Ready(Some(Err(InternalError::UnexpectedEof)));
538 }
539 }
540 }
541}
542
543impl Sink<Frame> for FrameSocket {
544 type Error = io::Error;
545
546 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
547 let this = &mut *self;
549
550 while this.sent < this.output_buffer.len() {
551 let pinned = Pin::new(&mut this.upgraded);
552
553 let len = ready!(pinned.poll_write(cx, &this.output_buffer[this.sent..]))?;
554
555 this.sent += len;
556 }
557
558 this.output_buffer.clear();
559 this.sent = 0;
560
561 Poll::Ready(Ok(()))
562 }
563
564 fn start_send(mut self: Pin<&mut Self>, frame: Frame) -> Result<(), Self::Error> {
565 let this = &mut *self;
567
568 frame.encode(&mut this.output_buffer, this.agent_role);
569
570 Ok(())
571 }
572
573 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
574 ready!(self.poll_ready_unpin(cx))?;
575
576 let pinned = Pin::new(&mut self.upgraded);
577
578 pinned.poll_flush(cx)
579 }
580
581 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
582 ready!(self.poll_ready_unpin(cx))?;
583
584 let pinned = Pin::new(&mut self.upgraded);
585
586 pinned.poll_shutdown(cx)
587 }
588}