Skip to main content

session_rs/ws/
mod.rs

1pub mod error;
2pub mod handshake;
3pub use error::{Error, Result};
4
5use std::{
6    hash::{Hash, Hasher},
7    sync::Arc,
8};
9use tokio::{
10    io::{AsyncReadExt, AsyncWriteExt},
11    sync::Mutex,
12};
13
/// A complete WebSocket message or control frame, as produced by `WebSocket::read`.
///
/// Deriving `PartialEq`/`Eq` lets callers (and tests) compare received frames directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// A text message (opcode `0x1`), decoded as UTF-8.
    Text(String),
    /// A binary message (opcode `0x2`).
    Binary(Vec<u8>),
    /// A ping (opcode `0x9`); its payload is not retained.
    Ping,
    /// A pong (opcode `0xA`); its payload is not retained.
    Pong,
    /// A close frame (opcode `0x8`); status code and reason are not retained.
    Close,
}
22
23pub struct WebSocket {
24    pub(crate) reader: Arc<Mutex<tokio::net::tcp::OwnedReadHalf>>,
25    pub(crate) writer: Arc<Mutex<tokio::net::tcp::OwnedWriteHalf>>,
26    pub(crate) id: u64,
27    pub(crate) is_server: bool,
28}
29
30impl Clone for WebSocket {
31    fn clone(&self) -> Self {
32        WebSocket {
33            reader: self.reader.clone(),
34            writer: self.writer.clone(),
35            is_server: self.is_server.clone(),
36            id: self.id,
37        }
38    }
39}
40
41impl PartialEq for WebSocket {
42    fn eq(&self, other: &Self) -> bool {
43        self.id == other.id
44    }
45}
46
47impl Eq for WebSocket {}
48
49impl Hash for WebSocket {
50    fn hash<H: Hasher>(&self, state: &mut H) {
51        self.id.hash(state);
52    }
53}
54
55impl WebSocket {
56    async fn send_frame(&self, opcode: u8, payload: &[u8]) -> Result<()> {
57        let mut writer = self.writer.lock().await;
58
59        let mut header = Vec::with_capacity(10);
60        let mask_bit = if self.is_server { 0x80 } else { 0x00 };
61        header.push(0x80 | opcode); // FIN + opcode
62
63        let len = payload.len();
64        if len < 126 {
65            header.push((len as u8) | mask_bit);
66        } else if len <= 0xFFFF {
67            header.push(126 | mask_bit);
68            header.extend_from_slice(&(len as u16).to_be_bytes());
69        } else {
70            header.push(127 | mask_bit);
71            header.extend_from_slice(&(len as u64).to_be_bytes());
72        }
73
74        if self.is_server {
75            // Generate 4-byte mask key
76            let mask_key: [u8; 4] = rand::random();
77            header.extend_from_slice(&mask_key);
78
79            // Mask the payload
80            let mut masked_payload = payload.to_vec();
81            for i in 0..masked_payload.len() {
82                masked_payload[i] ^= mask_key[i % 4];
83            }
84
85            writer.write_all(&header).await?;
86            writer.write_all(&masked_payload).await?;
87        } else {
88            writer.write_all(&header).await?;
89            writer.write_all(payload).await?;
90        }
91
92        writer.flush().await?;
93        Ok(())
94    }
95}
96
97impl WebSocket {
98    pub async fn send(&self, msg: &str) -> Result<()> {
99        self.send_frame(0x1, msg.as_bytes()).await
100    }
101
102    pub async fn send_text_payload(&self, payload: &[u8]) -> Result<()> {
103        self.send_frame(0x1, payload).await
104    }
105
106    pub async fn send_bin(&self, payload: &[u8]) -> Result<()> {
107        self.send_frame(0x2, payload).await
108    }
109
110    pub async fn send_ping(&self) -> Result<()> {
111        self.send_frame(0x9, &[]).await
112    }
113
114    pub async fn send_pong(&self) -> Result<()> {
115        self.send_frame(0xA, &[]).await
116    }
117
118    pub async fn close(&self) -> Result<()> {
119        self.send_frame(0x8, &[]).await
120    }
121
122    pub fn start_ping_loop(&self) {
123        let s = self.clone();
124        tokio::task::spawn(async move {
125            let mut interval = tokio::time::interval(std::time::Duration::from_secs(15));
126            loop {
127                interval.tick().await;
128                if s.send_ping().await.is_err() {
129                    break;
130                }
131            }
132        });
133    }
134}
135
136impl WebSocket {
137    /// Read a full WebSocket frame (handling masking and control frames)
138    /// Returns (opcode, payload)
139    pub async fn read_frame(&self) -> Result<(bool, u8, Vec<u8>)> {
140        let mut reader = self.reader.lock().await;
141
142        // --- 1. Read first 2-byte header ---
143        let mut header = [0u8; 2];
144        reader.read_exact(&mut header).await?;
145
146        let fin = header[0] & 0x80 != 0;
147        let opcode = header[0] & 0x0F;
148        let masked = header[1] & 0x80 != 0;
149        let mut payload_len = (header[1] & 0x7F) as u64;
150
151        // --- 2. Read extended payload length if necessary ---
152        if payload_len == 126 {
153            let mut buf = [0u8; 2];
154            reader.read_exact(&mut buf).await?;
155            payload_len = u16::from_be_bytes(buf) as u64;
156        } else if payload_len == 127 {
157            let mut buf = [0u8; 8];
158            reader.read_exact(&mut buf).await?;
159            payload_len = u64::from_be_bytes(buf);
160        }
161
162        let payload = if masked {
163            // --- 3. Read mask key ---
164            let mut mask = [0u8; 4];
165            reader.read_exact(&mut mask).await?;
166            let mut payload = vec![0u8; payload_len as usize];
167            if payload_len > 0 {
168                reader.read_exact(&mut payload).await?;
169                for i in 0..payload.len() {
170                    payload[i] ^= mask[i % 4];
171                }
172            }
173            payload
174        } else {
175            // Per spec, client-to-server frames MUST be masked
176            if !self.is_server {
177                self.close().await.ok();
178                return Err(Error::InvalidFrame(
179                    "Received unmasked frame from client".into(),
180                ));
181            }
182
183            let mut payload = vec![0u8; payload_len as usize];
184            if payload_len > 0 {
185                reader.read_exact(&mut payload).await?;
186            }
187            payload
188        };
189
190        // --- 6. Return opcode + payload ---
191        Ok((fin, opcode, payload))
192    }
193
194    pub async fn read(&self) -> Result<Frame> {
195        let (fin, opcode, mut payload) = self.read_frame().await?;
196
197        if !fin {
198            // Continuation loop
199            while let (fin, o, mut p) = self.read_frame().await?
200                && !fin
201            {
202                match o {
203                    // Continuation
204                    0x0 => payload.append(&mut p),
205                    // Close
206                    0x8 => {
207                        self.close().await.ok();
208                    }
209                    // Ping
210                    0x9 => {
211                        self.send_pong().await.ok();
212                    }
213                    // Pong
214                    0xA => {}
215                    _ => {
216                        self.close().await.ok();
217                        return Err(Error::InvalidFrame(format!("Unknown opcode: {opcode}")));
218                    }
219                }
220            }
221        }
222
223        match opcode {
224            // Close
225            0x8 => {
226                self.close().await.ok();
227                Ok(Frame::Close)
228            }
229
230            // Ping
231            0x9 => {
232                self.send_pong().await.ok();
233                Ok(Frame::Ping)
234            }
235
236            // Pong
237            0xA => Ok(Frame::Pong),
238
239            // Text
240            0x1 => Ok(Frame::Text(String::from_utf8(payload)?)),
241
242            // Binary
243            0x2 => Ok(Frame::Binary(payload)),
244
245            _ => {
246                self.close().await.ok();
247                Err(Error::InvalidFrame(format!("Unknown opcode: {opcode}")))
248            }
249        }
250    }
251}