Skip to main content

sigil_parser/protocol/
websocket.rs

1//! WebSocket Support
2//!
3//! Bidirectional real-time communication over WebSocket protocol.
4//!
5//! ## Features
6//!
7//! - Full WebSocket protocol support (RFC 6455)
8//! - Text and binary messages
9//! - Ping/pong handling
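//! - Reconnection with exponential backoff (caller-driven, via `ReconnectingWebSocket::reconnect`)
//! - TLS/WSS support
//! - Subprotocol configuration (requested subprotocols are stored on the builder but not yet sent in the handshake)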
13
14use super::common::{Headers, ProtocolError, ProtocolResult, Timeout, Uri};
15use std::time::Duration;
16
17/// WebSocket message types
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub enum Message {
20    /// UTF-8 text message
21    Text(String),
22    /// Binary message
23    Binary(Vec<u8>),
24    /// Ping message (client should respond with Pong)
25    Ping(Vec<u8>),
26    /// Pong message (response to Ping)
27    Pong(Vec<u8>),
28    /// Close message with optional code and reason
29    Close(Option<CloseFrame>),
30}
31
32impl Message {
33    /// Create a text message
34    pub fn text(s: impl Into<String>) -> Self {
35        Message::Text(s.into())
36    }
37
38    /// Create a binary message
39    pub fn binary(data: impl Into<Vec<u8>>) -> Self {
40        Message::Binary(data.into())
41    }
42
43    /// Create a ping message
44    pub fn ping(data: impl Into<Vec<u8>>) -> Self {
45        Message::Ping(data.into())
46    }
47
48    /// Create a pong message
49    pub fn pong(data: impl Into<Vec<u8>>) -> Self {
50        Message::Pong(data.into())
51    }
52
53    /// Create a close message
54    pub fn close(code: CloseCode, reason: impl Into<String>) -> Self {
55        Message::Close(Some(CloseFrame {
56            code,
57            reason: reason.into(),
58        }))
59    }
60
61    /// Check if this is a text message
62    pub fn is_text(&self) -> bool {
63        matches!(self, Message::Text(_))
64    }
65
66    /// Check if this is a binary message
67    pub fn is_binary(&self) -> bool {
68        matches!(self, Message::Binary(_))
69    }
70
71    /// Check if this is a ping message
72    pub fn is_ping(&self) -> bool {
73        matches!(self, Message::Ping(_))
74    }
75
76    /// Check if this is a pong message
77    pub fn is_pong(&self) -> bool {
78        matches!(self, Message::Pong(_))
79    }
80
81    /// Check if this is a close message
82    pub fn is_close(&self) -> bool {
83        matches!(self, Message::Close(_))
84    }
85
86    /// Check if this is a data message (text or binary)
87    pub fn is_data(&self) -> bool {
88        matches!(self, Message::Text(_) | Message::Binary(_))
89    }
90
91    /// Get the text content if this is a text message
92    pub fn as_text(&self) -> Option<&str> {
93        match self {
94            Message::Text(s) => Some(s),
95            _ => None,
96        }
97    }
98
99    /// Get the binary content if this is a binary message
100    pub fn as_binary(&self) -> Option<&[u8]> {
101        match self {
102            Message::Binary(b) => Some(b),
103            _ => None,
104        }
105    }
106
107    /// Convert to text (consumes the message)
108    pub fn into_text(self) -> Option<String> {
109        match self {
110            Message::Text(s) => Some(s),
111            _ => None,
112        }
113    }
114
115    /// Convert to bytes (consumes the message)
116    pub fn into_bytes(self) -> Option<Vec<u8>> {
117        match self {
118            Message::Binary(b) => Some(b),
119            _ => None,
120        }
121    }
122
123    /// Get the payload length
124    pub fn len(&self) -> usize {
125        match self {
126            Message::Text(s) => s.len(),
127            Message::Binary(b) | Message::Ping(b) | Message::Pong(b) => b.len(),
128            Message::Close(Some(frame)) => 2 + frame.reason.len(),
129            Message::Close(None) => 0,
130        }
131    }
132
133    /// Check if the payload is empty
134    pub fn is_empty(&self) -> bool {
135        self.len() == 0
136    }
137}
138
/// Payload of a close message: a status code and a human-readable reason.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseFrame {
    /// Close status code.
    pub code: CloseCode,
    /// Close reason (UTF-8 text).
    pub reason: String,
}

/// WebSocket close status codes (RFC 6455, section 7.4).
///
/// Known codes have named variants; any other value is carried by
/// [`CloseCode::Custom`]. Use [`CloseCode::as_u16`] (or `u16::from`) to obtain
/// the wire value: the explicit discriminants below only match the wire codes
/// for the named variants, and `Custom`'s implicit discriminant is meaningless.
///
/// Equality is structural, so `CloseCode::Custom(1000) != CloseCode::Normal`
/// even though both encode as 1000. Build codes from raw numbers with
/// [`CloseCode::from_u16`] / `CloseCode::from` so known values always map to
/// their named variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// Normal closure
    Normal = 1000,
    /// Endpoint going away
    GoingAway = 1001,
    /// Protocol error
    Protocol = 1002,
    /// Unsupported data type
    Unsupported = 1003,
    /// No status received (reserved, never sent)
    NoStatus = 1005,
    /// Abnormal closure (reserved, never sent)
    Abnormal = 1006,
    /// Invalid payload data
    InvalidData = 1007,
    /// Policy violation
    Policy = 1008,
    /// Message too big
    MessageTooBig = 1009,
    /// Missing extension
    MissingExtension = 1010,
    /// Internal server error
    InternalError = 1011,
    /// TLS handshake failure (reserved, never sent)
    TlsFailure = 1015,
    /// Any code without a named variant (e.g. application codes 4000-4999)
    Custom(u16),
}

impl CloseCode {
    /// Map a raw status code to its variant; unknown values become `Custom`.
    pub fn from_u16(code: u16) -> Self {
        match code {
            1000 => CloseCode::Normal,
            1001 => CloseCode::GoingAway,
            1002 => CloseCode::Protocol,
            1003 => CloseCode::Unsupported,
            1005 => CloseCode::NoStatus,
            1006 => CloseCode::Abnormal,
            1007 => CloseCode::InvalidData,
            1008 => CloseCode::Policy,
            1009 => CloseCode::MessageTooBig,
            1010 => CloseCode::MissingExtension,
            1011 => CloseCode::InternalError,
            1015 => CloseCode::TlsFailure,
            _ => CloseCode::Custom(code),
        }
    }

    /// The numeric status code as sent on the wire.
    pub fn as_u16(&self) -> u16 {
        match self {
            CloseCode::Normal => 1000,
            CloseCode::GoingAway => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::NoStatus => 1005,
            CloseCode::Abnormal => 1006,
            CloseCode::InvalidData => 1007,
            CloseCode::Policy => 1008,
            CloseCode::MessageTooBig => 1009,
            CloseCode::MissingExtension => 1010,
            CloseCode::InternalError => 1011,
            CloseCode::TlsFailure => 1015,
            CloseCode::Custom(code) => *code,
        }
    }

    /// Short human-readable description of the code.
    pub fn description(&self) -> &'static str {
        match self {
            CloseCode::Normal => "Normal closure",
            CloseCode::GoingAway => "Endpoint going away",
            CloseCode::Protocol => "Protocol error",
            CloseCode::Unsupported => "Unsupported data type",
            CloseCode::NoStatus => "No status received",
            CloseCode::Abnormal => "Abnormal closure",
            CloseCode::InvalidData => "Invalid payload data",
            CloseCode::Policy => "Policy violation",
            CloseCode::MessageTooBig => "Message too big",
            CloseCode::MissingExtension => "Missing extension",
            CloseCode::InternalError => "Internal server error",
            CloseCode::TlsFailure => "TLS handshake failure",
            CloseCode::Custom(_) => "Custom close code",
        }
    }
}

/// Same mapping as [`CloseCode::from_u16`].
impl From<u16> for CloseCode {
    fn from(code: u16) -> Self {
        CloseCode::from_u16(code)
    }
}

/// Same value as [`CloseCode::as_u16`].
impl From<CloseCode> for u16 {
    fn from(code: CloseCode) -> Self {
        code.as_u16()
    }
}

/// Formats as `"<description> (<code>)"`, e.g. `"Normal closure (1000)"`.
impl std::fmt::Display for CloseCode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} ({})", self.description(), self.as_u16())
    }
}
244
245/// WebSocket connection configuration
246#[derive(Debug, Clone)]
247pub struct WebSocketConfig {
248    /// Maximum message size
249    pub max_message_size: usize,
250    /// Maximum frame size
251    pub max_frame_size: usize,
252    /// Whether to accept unmasked frames from clients
253    pub accept_unmasked_frames: bool,
254    /// Subprotocols to request
255    pub subprotocols: Vec<String>,
256    /// Additional headers to send with the upgrade request
257    pub headers: Headers,
258    /// Connection timeout
259    pub connect_timeout: Duration,
260    /// Ping interval for keep-alive
261    pub ping_interval: Option<Duration>,
262    /// Pong timeout (close connection if no pong received)
263    pub pong_timeout: Option<Duration>,
264}
265
266impl Default for WebSocketConfig {
267    fn default() -> Self {
268        WebSocketConfig {
269            max_message_size: 64 * 1024 * 1024, // 64MB
270            max_frame_size: 16 * 1024 * 1024,   // 16MB
271            accept_unmasked_frames: false,
272            subprotocols: Vec::new(),
273            headers: Headers::new(),
274            connect_timeout: Duration::from_secs(30),
275            ping_interval: Some(Duration::from_secs(30)),
276            pong_timeout: Some(Duration::from_secs(10)),
277        }
278    }
279}
280
281/// WebSocket connection state
282#[derive(Debug, Clone, Copy, PartialEq, Eq)]
283pub enum ConnectionState {
284    /// Connection is being established
285    Connecting,
286    /// Connection is open and ready
287    Open,
288    /// Connection is closing
289    Closing,
290    /// Connection is closed
291    Closed,
292}
293
294/// WebSocket connection
295#[derive(Debug)]
296pub struct WebSocket {
297    /// The WebSocket URL
298    url: String,
299    /// Connection configuration
300    config: WebSocketConfig,
301    /// Current connection state
302    state: ConnectionState,
303    /// Negotiated subprotocol
304    subprotocol: Option<String>,
305    #[cfg(feature = "tokio-tungstenite")]
306    inner: Option<
307        tokio_tungstenite::WebSocketStream<
308            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
309        >,
310    >,
311}
312
313impl WebSocket {
314    /// Create a connection builder
315    pub fn builder(url: impl Into<String>) -> WebSocketBuilder {
316        WebSocketBuilder {
317            url: url.into(),
318            config: WebSocketConfig::default(),
319        }
320    }
321
322    /// Connect to a WebSocket server
323    #[cfg(feature = "tokio-tungstenite")]
324    pub async fn connect(url: impl Into<String>) -> ProtocolResult<Self> {
325        Self::builder(url).connect().await
326    }
327
328    /// Connect to a WebSocket server (stub for non-tungstenite builds)
329    #[cfg(not(feature = "tokio-tungstenite"))]
330    pub async fn connect(url: impl Into<String>) -> ProtocolResult<Self> {
331        let _ = url;
332        Err(ProtocolError::Protocol(
333            "WebSocket requires 'websocket' feature".to_string(),
334        ))
335    }
336
337    /// Get the connection URL
338    pub fn url(&self) -> &str {
339        &self.url
340    }
341
342    /// Get the current connection state
343    pub fn state(&self) -> ConnectionState {
344        self.state
345    }
346
347    /// Check if the connection is open
348    pub fn is_open(&self) -> bool {
349        self.state == ConnectionState::Open
350    }
351
352    /// Get the negotiated subprotocol
353    pub fn subprotocol(&self) -> Option<&str> {
354        self.subprotocol.as_deref()
355    }
356
357    /// Send a message
358    #[cfg(feature = "tokio-tungstenite")]
359    pub async fn send(&mut self, message: Message) -> ProtocolResult<()> {
360        use futures_util::SinkExt;
361        use tokio_tungstenite::tungstenite::Message as TMessage;
362
363        if self.state != ConnectionState::Open {
364            return Err(ProtocolError::ChannelClosed);
365        }
366
367        let msg = match message {
368            Message::Text(s) => TMessage::Text(s),
369            Message::Binary(b) => TMessage::Binary(b),
370            Message::Ping(b) => TMessage::Ping(b),
371            Message::Pong(b) => TMessage::Pong(b),
372            Message::Close(frame) => {
373                let close_frame = frame.map(|f| {
374                    tokio_tungstenite::tungstenite::protocol::CloseFrame {
375                        code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::from(f.code.as_u16()),
376                        reason: f.reason.into(),
377                    }
378                });
379                TMessage::Close(close_frame)
380            }
381        };
382
383        if let Some(ref mut inner) = self.inner {
384            inner
385                .send(msg)
386                .await
387                .map_err(|e| ProtocolError::Protocol(e.to_string()))?;
388        }
389
390        Ok(())
391    }
392
393    /// Send a message (stub for non-tungstenite builds)
394    #[cfg(not(feature = "tokio-tungstenite"))]
395    pub async fn send(&mut self, message: Message) -> ProtocolResult<()> {
396        let _ = message;
397        Err(ProtocolError::Protocol(
398            "WebSocket requires 'websocket' feature".to_string(),
399        ))
400    }
401
402    /// Receive a message
403    #[cfg(feature = "tokio-tungstenite")]
404    pub async fn recv(&mut self) -> ProtocolResult<Option<Message>> {
405        use futures_util::StreamExt;
406        use tokio_tungstenite::tungstenite::Message as TMessage;
407
408        if self.state != ConnectionState::Open {
409            return Ok(None);
410        }
411
412        if let Some(ref mut inner) = self.inner {
413            // Use loop instead of recursion to handle Frame messages
414            loop {
415                match inner.next().await {
416                    Some(Ok(msg)) => {
417                        let message = match msg {
418                            TMessage::Text(s) => Message::Text(s),
419                            TMessage::Binary(b) => Message::Binary(b),
420                            TMessage::Ping(b) => Message::Ping(b),
421                            TMessage::Pong(b) => Message::Pong(b),
422                            TMessage::Close(frame) => {
423                                self.state = ConnectionState::Closed;
424                                Message::Close(frame.map(|f| CloseFrame {
425                                    code: CloseCode::from_u16(f.code.into()),
426                                    reason: f.reason.to_string(),
427                                }))
428                            }
429                            TMessage::Frame(_) => continue, // Skip raw frames, get next message
430                        };
431                        return Ok(Some(message));
432                    }
433                    Some(Err(e)) => {
434                        self.state = ConnectionState::Closed;
435                        return Err(ProtocolError::Protocol(e.to_string()));
436                    }
437                    None => {
438                        self.state = ConnectionState::Closed;
439                        return Ok(None);
440                    }
441                }
442            }
443        } else {
444            Ok(None)
445        }
446    }
447
448    /// Receive a message (stub for non-tungstenite builds)
449    #[cfg(not(feature = "tokio-tungstenite"))]
450    pub async fn recv(&mut self) -> ProtocolResult<Option<Message>> {
451        Err(ProtocolError::Protocol(
452            "WebSocket requires 'websocket' feature".to_string(),
453        ))
454    }
455
456    /// Send a text message
457    pub async fn send_text(&mut self, text: impl Into<String>) -> ProtocolResult<()> {
458        self.send(Message::Text(text.into())).await
459    }
460
461    /// Send a binary message
462    pub async fn send_binary(&mut self, data: impl Into<Vec<u8>>) -> ProtocolResult<()> {
463        self.send(Message::Binary(data.into())).await
464    }
465
466    /// Send a ping
467    pub async fn ping(&mut self, data: impl Into<Vec<u8>>) -> ProtocolResult<()> {
468        self.send(Message::Ping(data.into())).await
469    }
470
471    /// Close the connection
472    pub async fn close(
473        &mut self,
474        code: CloseCode,
475        reason: impl Into<String>,
476    ) -> ProtocolResult<()> {
477        if self.state != ConnectionState::Open {
478            return Ok(());
479        }
480
481        self.state = ConnectionState::Closing;
482        self.send(Message::close(code, reason)).await?;
483        self.state = ConnectionState::Closed;
484        Ok(())
485    }
486
487    /// Close with normal status
488    pub async fn close_normal(&mut self) -> ProtocolResult<()> {
489        self.close(CloseCode::Normal, "").await
490    }
491}
492
493/// Builder for WebSocket connections
494#[derive(Debug, Clone)]
495pub struct WebSocketBuilder {
496    url: String,
497    config: WebSocketConfig,
498}
499
500impl WebSocketBuilder {
501    /// Set maximum message size
502    pub fn max_message_size(mut self, size: usize) -> Self {
503        self.config.max_message_size = size;
504        self
505    }
506
507    /// Set maximum frame size
508    pub fn max_frame_size(mut self, size: usize) -> Self {
509        self.config.max_frame_size = size;
510        self
511    }
512
513    /// Add a subprotocol to request
514    pub fn subprotocol(mut self, protocol: impl Into<String>) -> Self {
515        self.config.subprotocols.push(protocol.into());
516        self
517    }
518
519    /// Add subprotocols to request
520    pub fn subprotocols(mut self, protocols: Vec<String>) -> Self {
521        self.config.subprotocols.extend(protocols);
522        self
523    }
524
525    /// Add a header to the upgrade request
526    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
527        self.config.headers.insert(key, value);
528        self
529    }
530
531    /// Set connection timeout
532    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
533        self.config.connect_timeout = timeout;
534        self
535    }
536
537    /// Set ping interval for keep-alive
538    pub fn ping_interval(mut self, interval: Duration) -> Self {
539        self.config.ping_interval = Some(interval);
540        self
541    }
542
543    /// Disable ping keep-alive
544    pub fn no_ping(mut self) -> Self {
545        self.config.ping_interval = None;
546        self
547    }
548
549    /// Set bearer authentication
550    pub fn bearer_auth(self, token: impl Into<String>) -> Self {
551        self.header("Authorization", format!("Bearer {}", token.into()))
552    }
553
554    /// Connect to the WebSocket server
555    #[cfg(feature = "tokio-tungstenite")]
556    pub async fn connect(self) -> ProtocolResult<WebSocket> {
557        use tokio_tungstenite::connect_async;
558
559        let (ws_stream, _response) = connect_async(&self.url)
560            .await
561            .map_err(|e| ProtocolError::ConnectionFailed(e.to_string()))?;
562
563        Ok(WebSocket {
564            url: self.url,
565            config: self.config,
566            state: ConnectionState::Open,
567            subprotocol: None, // TODO: extract from response
568            inner: Some(ws_stream),
569        })
570    }
571
572    /// Connect to the WebSocket server (stub for non-tungstenite builds)
573    #[cfg(not(feature = "tokio-tungstenite"))]
574    pub async fn connect(self) -> ProtocolResult<WebSocket> {
575        Err(ProtocolError::Protocol(
576            "WebSocket requires 'websocket' feature".to_string(),
577        ))
578    }
579}
580
581/// Reconnecting WebSocket wrapper
582#[derive(Debug)]
583pub struct ReconnectingWebSocket {
584    /// Builder used to create connections
585    builder: WebSocketBuilder,
586    /// Current connection
587    connection: Option<WebSocket>,
588    /// Reconnect configuration
589    reconnect_config: ReconnectConfig,
590    /// Number of reconnection attempts
591    attempt_count: u32,
592}
593
594/// Reconnection configuration
595#[derive(Debug, Clone)]
596pub struct ReconnectConfig {
597    /// Initial delay before reconnecting
598    pub initial_delay: Duration,
599    /// Maximum delay between reconnection attempts
600    pub max_delay: Duration,
601    /// Delay multiplier for exponential backoff
602    pub multiplier: f64,
603    /// Maximum number of reconnection attempts (None = infinite)
604    pub max_attempts: Option<u32>,
605}
606
607impl Default for ReconnectConfig {
608    fn default() -> Self {
609        ReconnectConfig {
610            initial_delay: Duration::from_secs(1),
611            max_delay: Duration::from_secs(30),
612            multiplier: 2.0,
613            max_attempts: None,
614        }
615    }
616}
617
618impl ReconnectingWebSocket {
619    /// Create a new reconnecting WebSocket
620    pub fn new(url: impl Into<String>) -> Self {
621        ReconnectingWebSocket {
622            builder: WebSocket::builder(url),
623            connection: None,
624            reconnect_config: ReconnectConfig::default(),
625            attempt_count: 0,
626        }
627    }
628
629    /// Configure reconnection
630    pub fn reconnect_config(mut self, config: ReconnectConfig) -> Self {
631        self.reconnect_config = config;
632        self
633    }
634
635    /// Set initial reconnect delay
636    pub fn initial_delay(mut self, delay: Duration) -> Self {
637        self.reconnect_config.initial_delay = delay;
638        self
639    }
640
641    /// Set maximum reconnect delay
642    pub fn max_delay(mut self, delay: Duration) -> Self {
643        self.reconnect_config.max_delay = delay;
644        self
645    }
646
647    /// Set maximum reconnection attempts
648    pub fn max_attempts(mut self, attempts: u32) -> Self {
649        self.reconnect_config.max_attempts = Some(attempts);
650        self
651    }
652
653    /// Connect (or reconnect) to the server
654    pub async fn connect(&mut self) -> ProtocolResult<()> {
655        self.connection = Some(self.builder.clone().connect().await?);
656        self.attempt_count = 0;
657        Ok(())
658    }
659
660    /// Check if connected
661    pub fn is_connected(&self) -> bool {
662        self.connection
663            .as_ref()
664            .map(|c| c.is_open())
665            .unwrap_or(false)
666    }
667
668    /// Get the current connection
669    pub fn connection(&mut self) -> Option<&mut WebSocket> {
670        self.connection.as_mut()
671    }
672
673    /// Calculate the delay for the next reconnection attempt
674    fn next_delay(&self) -> Duration {
675        let delay = self.reconnect_config.initial_delay.mul_f64(
676            self.reconnect_config
677                .multiplier
678                .powi(self.attempt_count as i32),
679        );
680        delay.min(self.reconnect_config.max_delay)
681    }
682
683    /// Attempt to reconnect
684    pub async fn reconnect(&mut self) -> ProtocolResult<()> {
685        if let Some(max) = self.reconnect_config.max_attempts {
686            if self.attempt_count >= max {
687                return Err(ProtocolError::ConnectionFailed(format!(
688                    "Max reconnection attempts ({}) exceeded",
689                    max
690                )));
691            }
692        }
693
694        let delay = self.next_delay();
695        self.attempt_count += 1;
696
697        #[cfg(feature = "tokio")]
698        tokio::time::sleep(delay).await;
699
700        self.connection = Some(self.builder.clone().connect().await?);
701        self.attempt_count = 0;
702        Ok(())
703    }
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_types() {
        let text = Message::text("hello");
        assert!(text.is_text());
        assert!(text.is_data());
        assert_eq!(text.as_text(), Some("hello"));

        let binary = Message::binary(vec![1, 2, 3]);
        assert!(binary.is_binary());
        assert!(binary.is_data());
        assert_eq!(binary.as_binary(), Some(&[1u8, 2, 3][..]));

        let ping = Message::ping(vec![1, 2]);
        assert!(ping.is_ping());
        assert!(!ping.is_data());
    }

    /// Close payloads count the 2-byte status code plus the UTF-8 reason.
    #[test]
    fn test_message_len() {
        assert_eq!(Message::text("héllo").len(), 6);
        assert_eq!(Message::close(CloseCode::Normal, "bye").len(), 5);
        assert!(Message::Close(None).is_empty());
        assert!(Message::pong(Vec::<u8>::new()).is_empty());
    }

    #[test]
    fn test_close_codes() {
        assert_eq!(CloseCode::Normal.as_u16(), 1000);
        assert_eq!(CloseCode::from_u16(1000), CloseCode::Normal);
        assert_eq!(CloseCode::from_u16(4000), CloseCode::Custom(4000));
    }

    /// Every raw code survives a `from_u16` -> `as_u16` round trip.
    #[test]
    fn test_close_code_round_trip_and_display() {
        for code in 0u16..=5000 {
            assert_eq!(CloseCode::from_u16(code).as_u16(), code);
        }
        assert_eq!(CloseCode::Normal.to_string(), "Normal closure (1000)");
    }

    #[test]
    fn test_websocket_builder() {
        let builder = WebSocket::builder("wss://example.com/socket")
            .subprotocol("graphql-transport-ws")
            .bearer_auth("token123")
            .connect_timeout(Duration::from_secs(10));

        assert_eq!(builder.url, "wss://example.com/socket");
        assert!(builder
            .config
            .subprotocols
            .contains(&"graphql-transport-ws".to_string()));
        assert_eq!(builder.config.connect_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_reconnect_config() {
        let config = ReconnectConfig {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
            max_attempts: Some(5),
        };

        let ws = ReconnectingWebSocket::new("wss://example.com").reconnect_config(config);

        assert_eq!(ws.reconnect_config.max_attempts, Some(5));
    }

    /// Backoff doubles from the initial delay and is capped at `max_delay`.
    #[test]
    fn test_reconnect_backoff() {
        let mut ws = ReconnectingWebSocket::new("wss://example.com");
        assert_eq!(ws.next_delay(), Duration::from_secs(1));
        ws.attempt_count = 3;
        assert_eq!(ws.next_delay(), Duration::from_secs(8));
        ws.attempt_count = 10;
        assert_eq!(ws.next_delay(), Duration::from_secs(30));
    }
}