Skip to main content

sigil_parser/
websocket.rs

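//! Native WebSocket client implementation for Sigil.
//!
//! Implements the client side of the RFC 6455 WebSocket protocol directly on
//! `std::net::TcpStream`, using small helper crates (randomness, base64, SHA-1)
//! instead of a full WebSocket library.
//! Plain `ws://` connections are always available; `wss://` (TLS) connections
//! require the `websocket` feature, which enables TLS via `native_tls`.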
5
6use std::io::{Read, Write, BufRead, BufReader};
7use std::net::TcpStream;
8
9#[cfg(feature = "websocket")]
10use native_tls::TlsConnector;
11
/// Error returned by every fallible WebSocket operation.
///
/// Carries a human-readable description; its `Display` output is the description
/// prefixed with `"WebSocket error: "`.
#[derive(Debug)]
pub struct WebSocketError {
    /// Description of what went wrong.
    pub message: String,
}

impl WebSocketError {
    /// Build an error from anything convertible into a `String`.
    pub fn new(msg: impl Into<String>) -> Self {
        let message: String = msg.into();
        WebSocketError { message }
    }
}

impl std::fmt::Display for WebSocketError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WebSocket error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for WebSocketError {}
31
/// Frame opcodes defined by RFC 6455 §5.2.
///
/// Each discriminant is the on-the-wire 4-bit value, so `opcode as u8` is exactly
/// what goes into the low nibble of a frame's first byte.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum Opcode {
    /// Continuation of a fragmented message.
    Continuation = 0x0,
    /// Text data frame.
    Text = 0x1,
    /// Binary data frame.
    Binary = 0x2,
    /// Connection close (control frame).
    Close = 0x8,
    /// Ping (control frame).
    Ping = 0x9,
    /// Pong (control frame).
    Pong = 0xA,
}

impl Opcode {
    /// Map a raw opcode value to its variant; reserved or unknown values yield `None`.
    fn from_u8(val: u8) -> Option<Self> {
        const KNOWN: [Opcode; 6] = [
            Opcode::Continuation,
            Opcode::Text,
            Opcode::Binary,
            Opcode::Close,
            Opcode::Ping,
            Opcode::Pong,
        ];
        KNOWN.iter().copied().find(|&op| op as u8 == val)
    }
}
57
/// A complete WebSocket message as seen by users of this client.
#[derive(Clone, Debug)]
pub enum Message {
    /// A text message (UTF-8).
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// The peer sent a Close frame; any status code or reason is not retained.
    Close,
    /// Ping frame payload.
    Ping(Vec<u8>),
    /// Pong frame payload.
    Pong(Vec<u8>),
}
67
/// Byte stream underneath a WebSocket: raw TCP, or TLS over TCP.
///
/// The TLS variant only exists when the `websocket` feature is enabled.
enum Stream {
    /// Unencrypted connection (`ws://`).
    Plain(TcpStream),
    /// Encrypted connection (`wss://`).
    #[cfg(feature = "websocket")]
    Tls(native_tls::TlsStream<TcpStream>),
}

impl Read for Stream {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match *self {
            Self::Plain(ref mut tcp) => tcp.read(buf),
            #[cfg(feature = "websocket")]
            Self::Tls(ref mut tls) => tls.read(buf),
        }
    }
}

impl Write for Stream {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        match *self {
            Self::Plain(ref mut tcp) => tcp.write(buf),
            #[cfg(feature = "websocket")]
            Self::Tls(ref mut tls) => tls.write(buf),
        }
    }

    fn flush(&mut self) -> std::io::Result<()> {
        match *self {
            Self::Plain(ref mut tcp) => tcp.flush(),
            #[cfg(feature = "websocket")]
            Self::Tls(ref mut tls) => tls.flush(),
        }
    }
}
102
103/// Native WebSocket client
104pub struct WebSocket {
105    stream: Stream,
106}
107
108impl WebSocket {
109    /// Connect to a WebSocket server
110    ///
111    /// Supports ws:// and wss:// URLs
112    pub fn connect(url: &str) -> Result<Self, WebSocketError> {
113        // Parse URL
114        let (secure, host, port, path) = Self::parse_url(url)?;
115
116        // Establish TCP connection
117        let addr = format!("{}:{}", host, port);
118        let tcp_stream = TcpStream::connect(&addr)
119            .map_err(|e| WebSocketError::new(format!("TCP connection failed: {}", e)))?;
120
121        // Set timeouts
122        tcp_stream.set_read_timeout(Some(std::time::Duration::from_secs(30))).ok();
123        tcp_stream.set_write_timeout(Some(std::time::Duration::from_secs(30))).ok();
124
125        // Wrap in TLS if secure
126        let stream = if secure {
127            #[cfg(feature = "websocket")]
128            {
129                let connector = TlsConnector::new()
130                    .map_err(|e| WebSocketError::new(format!("TLS setup failed: {}", e)))?;
131                let tls_stream = connector.connect(&host, tcp_stream)
132                    .map_err(|e| WebSocketError::new(format!("TLS handshake failed: {}", e)))?;
133                Stream::Tls(tls_stream)
134            }
135            #[cfg(not(feature = "websocket"))]
136            {
137                return Err(WebSocketError::new("TLS support not compiled in"));
138            }
139        } else {
140            Stream::Plain(tcp_stream)
141        };
142
143        let mut ws = WebSocket { stream };
144
145        // Perform WebSocket handshake
146        ws.handshake(&host, port, &path)?;
147
148        Ok(ws)
149    }
150
151    /// Parse WebSocket URL into components
152    fn parse_url(url: &str) -> Result<(bool, String, u16, String), WebSocketError> {
153        let (secure, rest) = if url.starts_with("wss://") {
154            (true, &url[6..])
155        } else if url.starts_with("ws://") {
156            (false, &url[5..])
157        } else {
158            return Err(WebSocketError::new("URL must start with ws:// or wss://"));
159        };
160
161        // Split host:port from path
162        let (host_port, path) = match rest.find('/') {
163            Some(idx) => (&rest[..idx], &rest[idx..]),
164            None => (rest, "/"),
165        };
166
167        // Parse host and port
168        let (host, port) = match host_port.find(':') {
169            Some(idx) => {
170                let port_str = &host_port[idx + 1..];
171                let port = port_str.parse::<u16>()
172                    .map_err(|_| WebSocketError::new("Invalid port number"))?;
173                (host_port[..idx].to_string(), port)
174            }
175            None => (host_port.to_string(), if secure { 443 } else { 80 }),
176        };
177
178        Ok((secure, host, port, path.to_string()))
179    }
180
181    /// Perform WebSocket upgrade handshake (RFC 6455 Section 4)
182    fn handshake(&mut self, host: &str, port: u16, path: &str) -> Result<(), WebSocketError> {
183        // Generate random 16-byte key and base64 encode
184        let key_bytes: [u8; 16] = rand::random();
185        let key = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, key_bytes);
186
187        // Build HTTP upgrade request
188        let host_header = if port == 80 || port == 443 {
189            host.to_string()
190        } else {
191            format!("{}:{}", host, port)
192        };
193
194        let request = format!(
195            "GET {} HTTP/1.1\r\n\
196             Host: {}\r\n\
197             Upgrade: websocket\r\n\
198             Connection: Upgrade\r\n\
199             Sec-WebSocket-Key: {}\r\n\
200             Sec-WebSocket-Version: 13\r\n\
201             \r\n",
202            path, host_header, key
203        );
204
205        // Send request
206        self.stream.write_all(request.as_bytes())
207            .map_err(|e| WebSocketError::new(format!("Failed to send handshake: {}", e)))?;
208        self.stream.flush()
209            .map_err(|e| WebSocketError::new(format!("Failed to flush handshake: {}", e)))?;
210
211        // Read response
212        let mut reader = BufReader::new(&mut self.stream);
213        let mut response_line = String::new();
214        reader.read_line(&mut response_line)
215            .map_err(|e| WebSocketError::new(format!("Failed to read response: {}", e)))?;
216
217        // Check status
218        if !response_line.starts_with("HTTP/1.1 101") {
219            return Err(WebSocketError::new(format!("Handshake failed: {}", response_line.trim())));
220        }
221
222        // Read and validate headers
223        let expected_accept = Self::compute_accept_key(&key);
224        let mut found_accept = false;
225
226        loop {
227            let mut line = String::new();
228            reader.read_line(&mut line)
229                .map_err(|e| WebSocketError::new(format!("Failed to read headers: {}", e)))?;
230
231            let line = line.trim();
232            if line.is_empty() {
233                break; // End of headers
234            }
235
236            if let Some((name, value)) = line.split_once(':') {
237                let name = name.trim().to_lowercase();
238                let value = value.trim();
239
240                if name == "sec-websocket-accept" {
241                    if value != expected_accept {
242                        return Err(WebSocketError::new("Invalid Sec-WebSocket-Accept"));
243                    }
244                    found_accept = true;
245                }
246            }
247        }
248
249        if !found_accept {
250            return Err(WebSocketError::new("Missing Sec-WebSocket-Accept header"));
251        }
252
253        Ok(())
254    }
255
256    /// Compute the expected Sec-WebSocket-Accept value (RFC 6455 Section 4.2.2)
257    fn compute_accept_key(key: &str) -> String {
258        use sha1::{Sha1, Digest};
259
260        // Concatenate with magic GUID
261        let magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
262        let combined = format!("{}{}", key, magic);
263
264        // SHA-1 hash
265        let mut hasher = Sha1::new();
266        hasher.update(combined.as_bytes());
267        let hash = hasher.finalize();
268
269        // Base64 encode
270        base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash)
271    }
272
273    /// Send a text message
274    pub fn send_text(&mut self, text: &str) -> Result<(), WebSocketError> {
275        self.send_frame(Opcode::Text, text.as_bytes())
276    }
277
278    /// Send a binary message
279    pub fn send_binary(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
280        self.send_frame(Opcode::Binary, data)
281    }
282
283    /// Send a close frame
284    pub fn send_close(&mut self) -> Result<(), WebSocketError> {
285        self.send_frame(Opcode::Close, &[])
286    }
287
288    /// Send a WebSocket frame (RFC 6455 Section 5.2)
289    ///
290    /// Client frames MUST be masked per spec
291    fn send_frame(&mut self, opcode: Opcode, payload: &[u8]) -> Result<(), WebSocketError> {
292        let mut frame = Vec::with_capacity(14 + payload.len());
293
294        // First byte: FIN + opcode
295        frame.push(0x80 | (opcode as u8)); // FIN bit set
296
297        // Second byte: MASK + payload length
298        let len = payload.len();
299        if len < 126 {
300            frame.push(0x80 | len as u8); // Mask bit set
301        } else if len < 65536 {
302            frame.push(0x80 | 126);
303            frame.push((len >> 8) as u8);
304            frame.push(len as u8);
305        } else {
306            frame.push(0x80 | 127);
307            for i in (0..8).rev() {
308                frame.push((len >> (i * 8)) as u8);
309            }
310        }
311
312        // Masking key (4 random bytes)
313        let mask: [u8; 4] = rand::random();
314        frame.extend_from_slice(&mask);
315
316        // Masked payload
317        for (i, &byte) in payload.iter().enumerate() {
318            frame.push(byte ^ mask[i % 4]);
319        }
320
321        self.stream.write_all(&frame)
322            .map_err(|e| WebSocketError::new(format!("Failed to send frame: {}", e)))?;
323        self.stream.flush()
324            .map_err(|e| WebSocketError::new(format!("Failed to flush frame: {}", e)))?;
325
326        Ok(())
327    }
328
329    /// Receive a message
330    pub fn receive(&mut self) -> Result<Message, WebSocketError> {
331        loop {
332            let (opcode, payload) = self.receive_frame()?;
333
334            match opcode {
335                Opcode::Text => {
336                    let text = String::from_utf8(payload)
337                        .map_err(|e| WebSocketError::new(format!("Invalid UTF-8: {}", e)))?;
338                    return Ok(Message::Text(text));
339                }
340                Opcode::Binary => {
341                    return Ok(Message::Binary(payload));
342                }
343                Opcode::Close => {
344                    return Ok(Message::Close);
345                }
346                Opcode::Ping => {
347                    // Respond with pong
348                    self.send_frame(Opcode::Pong, &payload)?;
349                    // Continue to receive actual message
350                }
351                Opcode::Pong => {
352                    // Ignore pong, continue receiving
353                }
354                Opcode::Continuation => {
355                    // For simplicity, treat as text (full fragmentation support would need state)
356                    let text = String::from_utf8_lossy(&payload).to_string();
357                    return Ok(Message::Text(text));
358                }
359            }
360        }
361    }
362
363    /// Receive a WebSocket frame
364    fn receive_frame(&mut self) -> Result<(Opcode, Vec<u8>), WebSocketError> {
365        // Read first two bytes
366        let mut header = [0u8; 2];
367        self.read_exact(&mut header)?;
368
369        let _fin = (header[0] & 0x80) != 0;
370        let opcode = Opcode::from_u8(header[0] & 0x0F)
371            .ok_or_else(|| WebSocketError::new("Invalid opcode"))?;
372
373        let masked = (header[1] & 0x80) != 0;
374        let mut len = (header[1] & 0x7F) as usize;
375
376        // Extended payload length
377        if len == 126 {
378            let mut ext = [0u8; 2];
379            self.read_exact(&mut ext)?;
380            len = ((ext[0] as usize) << 8) | (ext[1] as usize);
381        } else if len == 127 {
382            let mut ext = [0u8; 8];
383            self.read_exact(&mut ext)?;
384            len = 0;
385            for &b in &ext {
386                len = (len << 8) | (b as usize);
387            }
388        }
389
390        // Read masking key if present (server frames usually aren't masked)
391        let mask = if masked {
392            let mut m = [0u8; 4];
393            self.read_exact(&mut m)?;
394            Some(m)
395        } else {
396            None
397        };
398
399        // Read payload
400        let mut payload = vec![0u8; len];
401        if len > 0 {
402            self.read_exact(&mut payload)?;
403        }
404
405        // Unmask if needed
406        if let Some(mask) = mask {
407            for (i, byte) in payload.iter_mut().enumerate() {
408                *byte ^= mask[i % 4];
409            }
410        }
411
412        Ok((opcode, payload))
413    }
414
415    /// Read exact number of bytes
416    fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), WebSocketError> {
417        let mut total = 0;
418        while total < buf.len() {
419            match self.stream.read(&mut buf[total..]) {
420                Ok(0) => return Err(WebSocketError::new("Connection closed")),
421                Ok(n) => total += n,
422                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
423                Err(e) => return Err(WebSocketError::new(format!("Read error: {}", e))),
424            }
425        }
426        Ok(())
427    }
428
429    /// Close the WebSocket connection gracefully
430    pub fn close(&mut self) -> Result<(), WebSocketError> {
431        // Send close frame
432        let _ = self.send_close();
433        Ok(())
434    }
435}
436
437/// Connect to a WebSocket server, send a message, receive response, and close
438///
439/// This is a convenience function for simple request-response patterns
440pub fn send_and_receive(url: &str, message: &str) -> Result<String, WebSocketError> {
441    let mut ws = WebSocket::connect(url)?;
442    ws.send_text(message)?;
443
444    let response = match ws.receive()? {
445        Message::Text(t) => t,
446        Message::Binary(b) => String::from_utf8_lossy(&b).to_string(),
447        Message::Close => String::new(),
448        Message::Ping(_) | Message::Pong(_) => {
449            // Try to receive actual message after ping/pong
450            match ws.receive()? {
451                Message::Text(t) => t,
452                Message::Binary(b) => String::from_utf8_lossy(&b).to_string(),
453                _ => String::new(),
454            }
455        }
456    };
457
458    ws.close()?;
459    Ok(response)
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_url() {
        let (secure, host, port, path) = WebSocket::parse_url("ws://example.com/path").unwrap();
        assert!(!secure);
        assert_eq!(host, "example.com");
        assert_eq!(port, 80);
        assert_eq!(path, "/path");

        let (secure, host, port, path) =
            WebSocket::parse_url("wss://example.com:8443/api").unwrap();
        assert!(secure);
        assert_eq!(host, "example.com");
        assert_eq!(port, 8443);
        assert_eq!(path, "/api");
    }

    #[test]
    fn test_parse_url_defaults() {
        // With no path the request path is "/", and the scheme's default port is used.
        let (secure, host, port, path) = WebSocket::parse_url("ws://example.com").unwrap();
        assert!(!secure);
        assert_eq!(host, "example.com");
        assert_eq!(port, 80);
        assert_eq!(path, "/");

        let (secure, _, port, path) = WebSocket::parse_url("wss://example.com").unwrap();
        assert!(secure);
        assert_eq!(port, 443);
        assert_eq!(path, "/");
    }

    #[test]
    fn test_parse_url_rejects_invalid() {
        assert!(WebSocket::parse_url("http://example.com/").is_err());
        assert!(WebSocket::parse_url("example.com").is_err());
        assert!(WebSocket::parse_url("ws://example.com:notaport/").is_err());
        assert!(WebSocket::parse_url("ws://example.com:70000/").is_err());
    }

    #[test]
    fn test_compute_accept_key() {
        // Test vector from RFC 6455 §1.3.
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let expected = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(WebSocket::compute_accept_key(key), expected);
    }

    #[test]
    fn test_opcode_round_trip() {
        let all = [
            Opcode::Continuation,
            Opcode::Text,
            Opcode::Binary,
            Opcode::Close,
            Opcode::Ping,
            Opcode::Pong,
        ];
        for op in all.iter() {
            assert_eq!(Opcode::from_u8(*op as u8), Some(*op));
        }
        // Reserved opcodes are rejected.
        assert_eq!(Opcode::from_u8(0x3), None);
        assert_eq!(Opcode::from_u8(0xF), None);
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            WebSocketError::new("boom").to_string(),
            "WebSocket error: boom"
        );
    }
}