1pub mod error;
2pub mod handshake;
3pub use error::{Error, Result};
4
5use std::{
6 hash::{Hash, Hasher},
7 sync::Arc,
8};
9use tokio::{
10 io::{AsyncReadExt, AsyncWriteExt},
11 sync::Mutex,
12};
13
/// A complete WebSocket message or control frame, as returned by `WebSocket::read`.
///
/// Equality is structural, so frames can be compared directly in tests and
/// dispatch code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// A text message; the payload has already been validated as UTF-8.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A ping was received; `read` has already attempted to send a pong in reply.
    Ping,
    /// A pong was received.
    Pong,
    /// A close frame was received; `read` has already attempted to send a close
    /// frame in reply.
    Close,
}
22
23pub struct WebSocket {
24 pub(crate) reader: Arc<Mutex<tokio::net::tcp::OwnedReadHalf>>,
25 pub(crate) writer: Arc<Mutex<tokio::net::tcp::OwnedWriteHalf>>,
26 pub(crate) id: u64,
27 pub(crate) is_server: bool,
28}
29
30impl Clone for WebSocket {
31 fn clone(&self) -> Self {
32 WebSocket {
33 reader: self.reader.clone(),
34 writer: self.writer.clone(),
35 is_server: self.is_server.clone(),
36 id: self.id,
37 }
38 }
39}
40
41impl PartialEq for WebSocket {
42 fn eq(&self, other: &Self) -> bool {
43 self.id == other.id
44 }
45}
46
47impl Eq for WebSocket {}
48
49impl Hash for WebSocket {
50 fn hash<H: Hasher>(&self, state: &mut H) {
51 self.id.hash(state);
52 }
53}
54
55impl WebSocket {
56 async fn send_frame(&self, opcode: u8, payload: &[u8]) -> Result<()> {
57 let mut writer = self.writer.lock().await;
58
59 let mut header = Vec::with_capacity(10);
60 let mask_bit = if self.is_server { 0x80 } else { 0x00 };
61 header.push(0x80 | opcode); let len = payload.len();
64 if len < 126 {
65 header.push((len as u8) | mask_bit);
66 } else if len <= 0xFFFF {
67 header.push(126 | mask_bit);
68 header.extend_from_slice(&(len as u16).to_be_bytes());
69 } else {
70 header.push(127 | mask_bit);
71 header.extend_from_slice(&(len as u64).to_be_bytes());
72 }
73
74 if self.is_server {
75 let mask_key: [u8; 4] = rand::random();
77 header.extend_from_slice(&mask_key);
78
79 let mut masked_payload = payload.to_vec();
81 for i in 0..masked_payload.len() {
82 masked_payload[i] ^= mask_key[i % 4];
83 }
84
85 writer.write_all(&header).await?;
86 writer.write_all(&masked_payload).await?;
87 } else {
88 writer.write_all(&header).await?;
89 writer.write_all(payload).await?;
90 }
91
92 writer.flush().await?;
93 Ok(())
94 }
95}
96
97impl WebSocket {
98 pub async fn send(&self, msg: &str) -> Result<()> {
99 self.send_frame(0x1, msg.as_bytes()).await
100 }
101
102 pub async fn send_text_payload(&self, payload: &[u8]) -> Result<()> {
103 self.send_frame(0x1, payload).await
104 }
105
106 pub async fn send_bin(&self, payload: &[u8]) -> Result<()> {
107 self.send_frame(0x2, payload).await
108 }
109
110 pub async fn send_ping(&self) -> Result<()> {
111 self.send_frame(0x9, &[]).await
112 }
113
114 pub async fn send_pong(&self) -> Result<()> {
115 self.send_frame(0xA, &[]).await
116 }
117
118 pub async fn close(&self) -> Result<()> {
119 self.send_frame(0x8, &[]).await
120 }
121
122 pub fn start_ping_loop(&self) {
123 let s = self.clone();
124 tokio::task::spawn(async move {
125 let mut interval = tokio::time::interval(std::time::Duration::from_secs(15));
126 loop {
127 interval.tick().await;
128 if s.send_ping().await.is_err() {
129 break;
130 }
131 }
132 });
133 }
134}
135
136impl WebSocket {
137 pub async fn read_frame(&self) -> Result<(bool, u8, Vec<u8>)> {
140 let mut reader = self.reader.lock().await;
141
142 let mut header = [0u8; 2];
144 reader.read_exact(&mut header).await?;
145
146 let fin = header[0] & 0x80 != 0;
147 let opcode = header[0] & 0x0F;
148 let masked = header[1] & 0x80 != 0;
149 let mut payload_len = (header[1] & 0x7F) as u64;
150
151 if payload_len == 126 {
153 let mut buf = [0u8; 2];
154 reader.read_exact(&mut buf).await?;
155 payload_len = u16::from_be_bytes(buf) as u64;
156 } else if payload_len == 127 {
157 let mut buf = [0u8; 8];
158 reader.read_exact(&mut buf).await?;
159 payload_len = u64::from_be_bytes(buf);
160 }
161
162 let payload = if masked {
163 let mut mask = [0u8; 4];
165 reader.read_exact(&mut mask).await?;
166 let mut payload = vec![0u8; payload_len as usize];
167 if payload_len > 0 {
168 reader.read_exact(&mut payload).await?;
169 for i in 0..payload.len() {
170 payload[i] ^= mask[i % 4];
171 }
172 }
173 payload
174 } else {
175 if !self.is_server {
177 self.close().await.ok();
178 return Err(Error::InvalidFrame(
179 "Received unmasked frame from client".into(),
180 ));
181 }
182
183 let mut payload = vec![0u8; payload_len as usize];
184 if payload_len > 0 {
185 reader.read_exact(&mut payload).await?;
186 }
187 payload
188 };
189
190 Ok((fin, opcode, payload))
192 }
193
194 pub async fn read(&self) -> Result<Frame> {
195 let (fin, opcode, mut payload) = self.read_frame().await?;
196
197 if !fin {
198 while let (fin, o, mut p) = self.read_frame().await?
200 && !fin
201 {
202 match o {
203 0x0 => payload.append(&mut p),
205 0x8 => {
207 self.close().await.ok();
208 }
209 0x9 => {
211 self.send_pong().await.ok();
212 }
213 0xA => {}
215 _ => {
216 self.close().await.ok();
217 return Err(Error::InvalidFrame(format!("Unknown opcode: {opcode}")));
218 }
219 }
220 }
221 }
222
223 match opcode {
224 0x8 => {
226 self.close().await.ok();
227 Ok(Frame::Close)
228 }
229
230 0x9 => {
232 self.send_pong().await.ok();
233 Ok(Frame::Ping)
234 }
235
236 0xA => Ok(Frame::Pong),
238
239 0x1 => Ok(Frame::Text(String::from_utf8(payload)?)),
241
242 0x2 => Ok(Frame::Binary(payload)),
244
245 _ => {
246 self.close().await.ok();
247 Err(Error::InvalidFrame(format!("Unknown opcode: {opcode}")))
248 }
249 }
250 }
251}