1use smol::io::{AsyncRead, AsyncWrite};
4use smol::prelude::*;
5use std::pin::Pin;
6use thiserror::Error;
7
8mod frame;
9mod handshake;
10mod mask;
11
12pub use frame::{Frame, OpCode};
13use handshake::{client_handshake, server_handshake};
14
15#[derive(Debug, Error)]
17pub enum Error {
18 #[error("I/O error: {0}")]
20 Io(#[from] std::io::Error),
21 #[error("WebSocket protocol error: {0}")]
23 Protocol(String),
24 #[error("WebSocket connection closed")]
26 ConnectionClosed,
27}
28
29pub type Result<T> = std::result::Result<T, Error>;
31
/// A WebSocket endpoint layered over an async byte stream `S`.
///
/// One type serves both roles: the `is_client` flag decides whether
/// outgoing frames are masked (client) or not (server).
pub struct WebSocket<S> {
    /// Transport that frames are read from and written to.
    stream: S,
    /// `true` when created by `connect`, `false` when created by `accept`.
    is_client: bool,
}
37
38impl<S> WebSocket<S>
39where
40 S: AsyncRead + AsyncWrite + Unpin,
41{
42 pub async fn accept(stream: S) -> Result<Self> {
48 let mut ws = WebSocket { stream, is_client: false };
49 server_handshake(&mut ws.stream).await?;
50 Ok(ws)
51 }
52
53 pub async fn connect(stream: S) -> Result<Self> {
59 let mut ws = WebSocket { stream, is_client: true };
60 client_handshake(&mut ws.stream).await?;
61 Ok(ws)
62 }
63
64 pub async fn send(&mut self, frame: Frame) -> Result<()> {
70 let mut data = frame.to_bytes();
71 if self.is_client {
72 mask::mask_frame(&mut data);
73 }
74 self.stream.write_all(&data).await?;
75 Ok(())
76 }
77
78 pub async fn receive(&mut self) -> Result<Frame> {
84 let frame = Frame::read_from(&mut self.stream).await?;
85 if !self.is_client && frame.is_masked() {
86 return Err(Error::Protocol("Client frames must be masked".into()));
87 }
88 if self.is_client && !frame.is_masked() {
89 return Err(Error::Protocol("Server frames must not be masked".into()));
90 }
91 Ok(frame)
92 }
93
94 pub async fn close(mut self) -> Result<()> {
100 let close_frame = Frame::close(None);
101 self.send(close_frame).await?;
102 loop {
104 match self.receive().await {
105 Ok(frame) if frame.is_close() => break,
106 Ok(_) => continue,
107 Err(Error::ConnectionClosed) => break,
108 Err(e) => return Err(e),
109 }
110 }
111 Ok(())
112 }
113}
114
115impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocket<S> {
116 fn poll_read(
117 mut self: Pin<&mut Self>,
118 cx: &mut std::task::Context<'_>,
119 buf: &mut [u8],
120 ) -> std::task::Poll<std::io::Result<usize>> {
121 Pin::new(&mut self.stream).poll_read(cx, buf)
122 }
123}
124
125impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocket<S> {
126 fn poll_write(
127 mut self: Pin<&mut Self>,
128 cx: &mut std::task::Context<'_>,
129 buf: &[u8],
130 ) -> std::task::Poll<std::io::Result<usize>> {
131 Pin::new(&mut self.stream).poll_write(cx, buf)
132 }
133
134 fn poll_flush(
135 mut self: Pin<&mut Self>,
136 cx: &mut std::task::Context<'_>,
137 ) -> std::task::Poll<std::io::Result<()>> {
138 Pin::new(&mut self.stream).poll_flush(cx)
139 }
140
141 fn poll_close(
142 mut self: Pin<&mut Self>,
143 cx: &mut std::task::Context<'_>,
144 ) -> std::task::Poll<std::io::Result<()>> {
145 Pin::new(&mut self.stream).poll_close(cx)
146 }
147}