rusty_socket/
lib.rs

//! A minimal WebSocket implementation.
2
3use smol::io::{AsyncRead, AsyncWrite};
4use smol::prelude::*;
5use std::pin::Pin;
6use thiserror::Error;
7
8mod frame;
9mod handshake;
10mod mask;
11
12pub use frame::{Frame, OpCode};
13use handshake::{client_handshake, server_handshake};
14
15/// Represents errors that can occur in WebSocket operations.
16#[derive(Debug, Error)]
17pub enum Error {
18    /// An I/O error occurred.
19    #[error("I/O error: {0}")]
20    Io(#[from] std::io::Error),
21    /// A WebSocket protocol error occurred.
22    #[error("WebSocket protocol error: {0}")]
23    Protocol(String),
24    /// The WebSocket connection was closed.
25    #[error("WebSocket connection closed")]
26    ConnectionClosed,
27}
28
29/// A Result type alias for WebSocket operations.
30pub type Result<T> = std::result::Result<T, Error>;
31
/// A WebSocket connection layered over an async byte stream `S`.
///
/// The role chosen at construction (client via `connect`, server via `accept`)
/// decides masking. Outgoing frames are masked only in the client role, and
/// incoming frames are validated against the peer's role.
#[derive(Debug)]
pub struct WebSocket<S> {
    /// The underlying transport the frames are read from and written to.
    stream: S,
    /// `true` when this end performed the client handshake (`connect`),
    /// `false` when it accepted the connection as a server (`accept`).
    is_client: bool,
}
37
38impl<S> WebSocket<S>
39where
40    S: AsyncRead + AsyncWrite + Unpin,
41{
42    /// Accepts a WebSocket connection as a server.
43    ///
44    /// # Errors
45    ///
46    /// Returns an error if the handshake fails.
47    pub async fn accept(stream: S) -> Result<Self> {
48        let mut ws = WebSocket { stream, is_client: false };
49        server_handshake(&mut ws.stream).await?;
50        Ok(ws)
51    }
52
53    /// Connects to a WebSocket server as a client.
54    ///
55    /// # Errors
56    ///
57    /// Returns an error if the handshake fails.
58    pub async fn connect(stream: S) -> Result<Self> {
59        let mut ws = WebSocket { stream, is_client: true };
60        client_handshake(&mut ws.stream).await?;
61        Ok(ws)
62    }
63
64    /// Sends a WebSocket frame.
65    ///
66    /// # Errors
67    ///
68    /// Returns an error if sending the frame fails.
69    pub async fn send(&mut self, frame: Frame) -> Result<()> {
70        let mut data = frame.to_bytes();
71        if self.is_client {
72            mask::mask_frame(&mut data);
73        }
74        self.stream.write_all(&data).await?;
75        Ok(())
76    }
77
78    /// Receives a WebSocket frame.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if receiving the frame fails or if the frame is invalid.
83    pub async fn receive(&mut self) -> Result<Frame> {
84        let frame = Frame::read_from(&mut self.stream).await?;
85        if !self.is_client && frame.is_masked() {
86            return Err(Error::Protocol("Client frames must be masked".into()));
87        }
88        if self.is_client && !frame.is_masked() {
89            return Err(Error::Protocol("Server frames must not be masked".into()));
90        }
91        Ok(frame)
92    }
93
94    /// Closes the WebSocket connection.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if closing the connection fails.
99    pub async fn close(mut self) -> Result<()> {
100        let close_frame = Frame::close(None);
101        self.send(close_frame).await?;
102        // Wait for the close frame from the other side
103        loop {
104            match self.receive().await {
105                Ok(frame) if frame.is_close() => break,
106                Ok(_) => continue,
107                Err(Error::ConnectionClosed) => break,
108                Err(e) => return Err(e),
109            }
110        }
111        Ok(())
112    }
113}
114
115impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocket<S> {
116    fn poll_read(
117        mut self: Pin<&mut Self>,
118        cx: &mut std::task::Context<'_>,
119        buf: &mut [u8],
120    ) -> std::task::Poll<std::io::Result<usize>> {
121        Pin::new(&mut self.stream).poll_read(cx, buf)
122    }
123}
124
125impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocket<S> {
126    fn poll_write(
127        mut self: Pin<&mut Self>,
128        cx: &mut std::task::Context<'_>,
129        buf: &[u8],
130    ) -> std::task::Poll<std::io::Result<usize>> {
131        Pin::new(&mut self.stream).poll_write(cx, buf)
132    }
133
134    fn poll_flush(
135        mut self: Pin<&mut Self>,
136        cx: &mut std::task::Context<'_>,
137    ) -> std::task::Poll<std::io::Result<()>> {
138        Pin::new(&mut self.stream).poll_flush(cx)
139    }
140
141    fn poll_close(
142        mut self: Pin<&mut Self>,
143        cx: &mut std::task::Context<'_>,
144    ) -> std::task::Poll<std::io::Result<()>> {
145        Pin::new(&mut self.stream).poll_close(cx)
146    }
147}