Skip to main content

syncable_ag_ui_client/
ws.rs

//! WebSocket Client
//!
//! This module provides a client for consuming AG-UI events over a WebSocket connection.
//!
//! # Example
//!
//! ```rust,ignore
//! use syncable_ag_ui_client::WsClient;
//! use futures::StreamExt;
//!
//! let client = WsClient::connect("ws://localhost:3000/ws").await?;
//! let mut stream = client.into_stream();
//!
//! while let Some(event) = stream.next().await {
//!     println!("Event: {:?}", event?.event_type());
//! }
//! ```
18
19use std::pin::Pin;
20use std::task::{Context, Poll};
21
22use syncable_ag_ui_core::{Event, JsonValue};
23use futures::{SinkExt, Stream};
24use tokio_tungstenite::{
25    connect_async,
26    tungstenite::{self, Message},
27    MaybeTlsStream, WebSocketStream,
28};
29
30use crate::error::{ClientError, Result};
31
/// Options controlling how a [`WsClient`] opens its WebSocket connection.
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// Extra `(name, value)` header pairs sent with the HTTP upgrade request.
    pub headers: Vec<(String, String)>,
    /// When `true`, incoming ping frames are answered with a pong.
    pub auto_pong: bool,
}

impl Default for WsConfig {
    /// No extra headers, with automatic pong replies enabled.
    fn default() -> Self {
        WsConfig {
            headers: Vec::new(),
            auto_pong: true,
        }
    }
}

impl WsConfig {
    /// Returns the default configuration (identical to [`WsConfig::default`]).
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends a `name: value` header to the upgrade request and returns the updated config.
    ///
    /// Headers are kept in insertion order; adding the same name twice keeps both entries.
    pub fn header(self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let mut headers = self.headers;
        headers.push((name.into(), value.into()));
        Self { headers, ..self }
    }

    /// Appends an `Authorization: Bearer <token>` header.
    pub fn bearer_token(self, token: impl Into<String>) -> Self {
        let credentials = format!("Bearer {}", token.into());
        self.header("Authorization", credentials)
    }

    /// Turns off automatic pong replies to server pings.
    pub fn disable_auto_pong(self) -> Self {
        Self {
            auto_pong: false,
            ..self
        }
    }
}
73
74/// WebSocket client for consuming AG-UI event streams.
75///
76/// The client connects to a WebSocket endpoint and provides a stream of
77/// parsed AG-UI events.
78pub struct WsClient {
79    socket: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
80    auto_pong: bool,
81}
82
83impl WsClient {
84    /// Connects to a WebSocket endpoint with default configuration.
85    ///
86    /// # Arguments
87    ///
88    /// * `url` - The WebSocket endpoint URL (ws:// or wss://)
89    ///
90    /// # Example
91    ///
92    /// ```rust,ignore
93    /// let client = WsClient::connect("ws://localhost:3000/ws").await?;
94    /// ```
95    pub async fn connect(url: &str) -> Result<Self> {
96        Self::connect_with_config(url, WsConfig::default()).await
97    }
98
99    /// Connects to a WebSocket endpoint with custom configuration.
100    ///
101    /// # Arguments
102    ///
103    /// * `url` - The WebSocket endpoint URL (ws:// or wss://)
104    /// * `config` - Connection configuration
105    ///
106    /// # Example
107    ///
108    /// ```rust,ignore
109    /// let config = WsConfig::new()
110    ///     .bearer_token("my-token");
111    /// let client = WsClient::connect_with_config("ws://localhost:3000/ws", config).await?;
112    /// ```
113    pub async fn connect_with_config(url: &str, config: WsConfig) -> Result<Self> {
114        // Build the request with custom headers
115        let mut request = tungstenite::http::Request::builder()
116            .uri(url)
117            .header("Host", extract_host(url)?)
118            .header("Connection", "Upgrade")
119            .header("Upgrade", "websocket")
120            .header("Sec-WebSocket-Version", "13")
121            .header(
122                "Sec-WebSocket-Key",
123                tungstenite::handshake::client::generate_key(),
124            );
125
126        for (name, value) in config.headers {
127            request = request.header(name, value);
128        }
129
130        let request = request
131            .body(())
132            .map_err(|e| ClientError::connection(e.to_string()))?;
133
134        let (socket, _response) = connect_async(request)
135            .await
136            .map_err(|e| ClientError::connection(e.to_string()))?;
137
138        Ok(Self {
139            socket,
140            auto_pong: config.auto_pong,
141        })
142    }
143
144    /// Converts this client into an event stream.
145    ///
146    /// The stream yields parsed AG-UI events as they arrive.
147    pub fn into_stream(self) -> WsEventStream {
148        WsEventStream {
149            socket: self.socket,
150            auto_pong: self.auto_pong,
151        }
152    }
153
154    /// Closes the WebSocket connection gracefully.
155    pub async fn close(mut self) -> Result<()> {
156        self.socket
157            .close(None)
158            .await
159            .map_err(|e| ClientError::connection(e.to_string()))
160    }
161}
162
163/// A stream of AG-UI events from a WebSocket connection.
164///
165/// This stream yields `Result<Event>` items as events arrive from the server.
166pub struct WsEventStream {
167    socket: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
168    auto_pong: bool,
169}
170
171impl Stream for WsEventStream {
172    type Item = Result<Event<JsonValue>>;
173
174    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
175        loop {
176            match Pin::new(&mut self.socket).poll_next(cx) {
177                Poll::Ready(Some(Ok(msg))) => {
178                    match msg {
179                        Message::Text(text) => {
180                            // Parse the event data as JSON
181                            match serde_json::from_str::<Event<JsonValue>>(&text) {
182                                Ok(event) => return Poll::Ready(Some(Ok(event))),
183                                Err(e) => {
184                                    return Poll::Ready(Some(Err(ClientError::parse(format!(
185                                        "failed to parse event: {}",
186                                        e
187                                    )))))
188                                }
189                            }
190                        }
191                        Message::Ping(data) => {
192                            if self.auto_pong {
193                                // Send pong response
194                                let mut socket = Pin::new(&mut self.socket);
195                                let _ = socket.start_send_unpin(Message::Pong(data));
196                            }
197                            continue;
198                        }
199                        Message::Pong(_) => {
200                            // Ignore pong messages
201                            continue;
202                        }
203                        Message::Close(_) => {
204                            return Poll::Ready(None);
205                        }
206                        Message::Binary(_) | Message::Frame(_) => {
207                            // Ignore binary/frame messages for AG-UI
208                            continue;
209                        }
210                    }
211                }
212                Poll::Ready(Some(Err(e))) => {
213                    return Poll::Ready(Some(Err(ClientError::WebSocket(e))))
214                }
215                Poll::Ready(None) => return Poll::Ready(None),
216                Poll::Pending => return Poll::Pending,
217            }
218        }
219    }
220}
221
222/// Extracts the host from a URL for the Host header.
223fn extract_host(url: &str) -> Result<String> {
224    let url = url::Url::parse(url).map_err(|e| ClientError::InvalidUrl(e.to_string()))?;
225
226    let host = url
227        .host_str()
228        .ok_or_else(|| ClientError::InvalidUrl("missing host".to_string()))?;
229
230    match url.port() {
231        Some(port) => Ok(format!("{}:{}", host, port)),
232        None => Ok(host.to_string()),
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ws_config_default() {
        let config = WsConfig::default();
        assert!(config.headers.is_empty());
        assert!(config.auto_pong);
    }

    #[test]
    fn test_ws_config_builder() {
        let config = WsConfig::new()
            .header("X-Custom", "value")
            .bearer_token("token123")
            .disable_auto_pong();

        assert_eq!(config.headers.len(), 2);
        assert_eq!(config.headers[0], ("X-Custom".to_string(), "value".to_string()));
        assert_eq!(
            config.headers[1],
            ("Authorization".to_string(), "Bearer token123".to_string())
        );
        assert!(!config.auto_pong);
    }

    #[test]
    fn test_extract_host() {
        assert_eq!(extract_host("ws://localhost:3000/ws").unwrap(), "localhost:3000");
        assert_eq!(extract_host("wss://example.com/events").unwrap(), "example.com");
        assert_eq!(
            extract_host("ws://api.example.com:8080/stream").unwrap(),
            "api.example.com:8080"
        );
    }

    /// IPv6 literals must keep their brackets so the `host:port` form stays unambiguous.
    #[test]
    fn test_extract_host_ipv6() {
        assert_eq!(extract_host("ws://[::1]:9000/ws").unwrap(), "[::1]:9000");
    }

    /// An explicitly written default port for the scheme is not repeated in the Host value.
    #[test]
    fn test_extract_host_default_port() {
        assert_eq!(extract_host("ws://example.com:80/ws").unwrap(), "example.com");
        assert_eq!(extract_host("wss://example.com:443/ws").unwrap(), "example.com");
    }

    #[test]
    fn test_extract_host_invalid() {
        assert!(matches!(
            extract_host("not a url"),
            Err(ClientError::InvalidUrl(_))
        ));
    }

    /// A URL that parses but has no host is rejected as `InvalidUrl`.
    #[test]
    fn test_extract_host_missing_host() {
        assert!(matches!(
            extract_host("mailto:someone@example.com"),
            Err(ClientError::InvalidUrl(_))
        ));
    }
}