Skip to main content

syncable_ag_ui_server/transport/
ws.rs

1//! WebSocket Transport for AG-UI Events
2//!
3//! This module provides WebSocket transport for streaming AG-UI events to frontend clients.
4//! It integrates with axum to provide WebSocket endpoints as an alternative to SSE.
5//!
6//! # Architecture
7//!
8//! The WebSocket transport uses a channel-based design similar to SSE:
9//! - [`WsSender`] - Used by agent code to send events into the WebSocket stream
10//! - [`WsHandler`] - Handles the WebSocket connection and streams events
11//!
12//! # Example
13//!
14//! ```rust,ignore
//! use syncable_ag_ui_server::transport::ws;
//! use syncable_ag_ui_core::{Event, TextMessageStartEvent, MessageId};
//! use axum::extract::ws::WebSocketUpgrade;
//! use axum::response::IntoResponse;
18//!
19//! async fn ws_endpoint(upgrade: WebSocketUpgrade) -> impl IntoResponse {
20//!     let (sender, handler) = ws::channel::<serde_json::Value>(32);
21//!
22//!     // Spawn task to send events
23//!     tokio::spawn(async move {
24//!         let event = Event::TextMessageStart(
25//!             TextMessageStartEvent::new(MessageId::random())
26//!         );
27//!         sender.send(event).await.ok();
28//!     });
29//!
30//!     handler.into_response(upgrade)
31//! }
32//! ```
33//!
34//! # SSE vs WebSocket
35//!
36//! Choose WebSocket when:
37//! - You need bidirectional communication (future AG-UI extensions)
38//! - You want lower latency for high-frequency updates
39//! - You need to work around SSE connection limits in browsers
40//!
41//! Choose SSE when:
42//! - You only need server-to-client streaming (current AG-UI)
43//! - You want automatic reconnection (built into EventSource)
44//! - You need HTTP/2 multiplexing benefits
45
46use std::time::Duration;
47
48use syncable_ag_ui_core::{AgentState, Event, JsonValue};
49use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
50use axum::response::IntoResponse;
51use futures::{SinkExt, StreamExt};
52use tokio::sync::mpsc;
53use tokio::time::interval;
54
55use crate::error::ServerError;
56
/// Keep-alive ping period used by `WsConfig::default()`: 30 seconds.
pub const DEFAULT_PING_INTERVAL: Duration = Duration::from_millis(30_000);
59
/// Error returned when an event could not be queued on a WebSocket channel.
///
/// The rejected value is handed back in field `0` so the caller can inspect or
/// retry it. Note that `WsSender::try_send` also produces this error when the
/// channel buffer is full, even though the `Display` text always reads
/// "WebSocket channel closed".
#[derive(Debug, Clone)]
pub struct SendError<T>(pub T);

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WebSocket channel closed")
    }
}

impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}
71
/// Settings that control how a [`WsHandler`] manages its connection.
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// Time to wait between consecutive keep-alive pings.
    pub ping_interval: Duration,
    /// When `false`, no keep-alive pings are sent at all.
    pub enable_ping: bool,
}
80
81impl Default for WsConfig {
82    fn default() -> Self {
83        Self {
84            ping_interval: DEFAULT_PING_INTERVAL,
85            enable_ping: true,
86        }
87    }
88}
89
90impl WsConfig {
91    /// Creates a new configuration with default values.
92    pub fn new() -> Self {
93        Self::default()
94    }
95
96    /// Sets the ping interval.
97    pub fn ping_interval(mut self, interval: Duration) -> Self {
98        self.ping_interval = interval;
99        self
100    }
101
102    /// Disables ping messages.
103    pub fn disable_ping(mut self) -> Self {
104        self.enable_ping = false;
105        self
106    }
107}
108
109/// Sender side of a WebSocket channel.
110///
111/// Use this to send AG-UI events that will be streamed to connected clients.
112/// Events are serialized to JSON and sent as WebSocket text messages.
113#[derive(Debug, Clone)]
114pub struct WsSender<StateT: AgentState = JsonValue> {
115    sender: mpsc::Sender<Event<StateT>>,
116}
117
118impl<StateT: AgentState> WsSender<StateT> {
119    /// Sends an event to the WebSocket stream.
120    ///
121    /// Returns an error if the receiver has been dropped (client disconnected).
122    pub async fn send(&self, event: Event<StateT>) -> Result<(), SendError<Event<StateT>>> {
123        self.sender.send(event).await.map_err(|e| SendError(e.0))
124    }
125
126    /// Sends multiple events to the WebSocket stream.
127    ///
128    /// Stops and returns an error on the first failed send.
129    pub async fn send_many(
130        &self,
131        events: impl IntoIterator<Item = Event<StateT>>,
132    ) -> Result<(), SendError<Event<StateT>>> {
133        for event in events {
134            self.send(event).await?;
135        }
136        Ok(())
137    }
138
139    /// Tries to send an event without waiting.
140    ///
141    /// Returns an error if the channel is full or closed.
142    pub fn try_send(&self, event: Event<StateT>) -> Result<(), SendError<Event<StateT>>> {
143        self.sender
144            .try_send(event)
145            .map_err(|e| SendError(e.into_inner()))
146    }
147
148    /// Checks if the receiver is still connected.
149    pub fn is_closed(&self) -> bool {
150        self.sender.is_closed()
151    }
152}
153
154/// Handler side of a WebSocket channel.
155///
156/// This handles the WebSocket connection and streams events from the sender.
157pub struct WsHandler<StateT: AgentState = JsonValue> {
158    receiver: mpsc::Receiver<Event<StateT>>,
159    config: WsConfig,
160}
161
162impl<StateT: AgentState> WsHandler<StateT> {
163    /// Converts a WebSocket upgrade into an axum response.
164    ///
165    /// The response will upgrade to WebSocket and stream events as they are
166    /// sent through the corresponding [`WsSender`].
167    pub fn into_response(self, upgrade: WebSocketUpgrade) -> impl IntoResponse {
168        upgrade.on_upgrade(move |socket| self.handle_socket(socket))
169    }
170
171    /// Handles the WebSocket connection.
172    async fn handle_socket(self, socket: WebSocket) {
173        let (mut ws_sender, mut ws_receiver) = socket.split();
174        let mut event_receiver = self.receiver;
175
176        // Create ping interval if enabled
177        let mut ping_interval = if self.config.enable_ping {
178            Some(interval(self.config.ping_interval))
179        } else {
180            None
181        };
182
183        loop {
184            tokio::select! {
185                // Handle incoming events to send
186                event = event_receiver.recv() => {
187                    match event {
188                        Some(event) => {
189                            // Serialize event to JSON
190                            let json = match serde_json::to_string(&event) {
191                                Ok(json) => json,
192                                Err(e) => {
193                                    eprintln!("WebSocket serialization error: {}", e);
194                                    continue;
195                                }
196                            };
197
198                            // Send as text message
199                            if ws_sender.send(Message::Text(json.into())).await.is_err() {
200                                // Client disconnected
201                                break;
202                            }
203                        }
204                        None => {
205                            // Event channel closed, send close frame and exit
206                            let _ = ws_sender.send(Message::Close(None)).await;
207                            break;
208                        }
209                    }
210                }
211
212                // Handle ping interval
213                _ = async {
214                    if let Some(ref mut interval) = ping_interval {
215                        interval.tick().await;
216                    } else {
217                        // Never completes if ping disabled
218                        std::future::pending::<()>().await;
219                    }
220                } => {
221                    if ws_sender.send(Message::Ping(vec![].into())).await.is_err() {
222                        break;
223                    }
224                }
225
226                // Handle incoming WebSocket messages (for close/pong)
227                msg = ws_receiver.next() => {
228                    match msg {
229                        Some(Ok(Message::Pong(_))) => {
230                            // Pong received, connection is alive
231                        }
232                        Some(Ok(Message::Close(_))) | None => {
233                            // Client closed connection
234                            break;
235                        }
236                        Some(Ok(_)) => {
237                            // Ignore other message types (Text, Binary)
238                            // AG-UI is unidirectional server->client
239                        }
240                        Some(Err(_)) => {
241                            // WebSocket error
242                            break;
243                        }
244                    }
245                }
246            }
247        }
248    }
249}
250
251/// Creates a new WebSocket channel pair with default configuration.
252///
253/// The `buffer` parameter controls how many events can be queued before
254/// sends will block (or fail for `try_send`).
255///
256/// # Arguments
257///
258/// * `buffer` - The capacity of the internal channel buffer
259///
260/// # Returns
261///
262/// A tuple of (`WsSender`, `WsHandler`) that are connected.
263///
264/// # Example
265///
266/// ```rust,ignore
267/// let (sender, handler) = ws::channel::<serde_json::Value>(32);
268/// ```
269pub fn channel<StateT: AgentState>(buffer: usize) -> (WsSender<StateT>, WsHandler<StateT>) {
270    channel_with_config(buffer, WsConfig::default())
271}
272
273/// Creates a new WebSocket channel pair with custom configuration.
274///
275/// # Arguments
276///
277/// * `buffer` - The capacity of the internal channel buffer
278/// * `config` - WebSocket configuration options
279///
280/// # Returns
281///
282/// A tuple of (`WsSender`, `WsHandler`) that are connected.
283///
284/// # Example
285///
286/// ```rust,ignore
287/// let config = WsConfig::new()
288///     .ping_interval(Duration::from_secs(15))
289///     .disable_ping();
290/// let (sender, handler) = ws::channel_with_config::<serde_json::Value>(32, config);
291/// ```
292pub fn channel_with_config<StateT: AgentState>(
293    buffer: usize,
294    config: WsConfig,
295) -> (WsSender<StateT>, WsHandler<StateT>) {
296    let (tx, rx) = mpsc::channel(buffer);
297    (
298        WsSender { sender: tx },
299        WsHandler {
300            receiver: rx,
301            config,
302        },
303    )
304}
305
306/// Serializes an event to a WebSocket text message.
307///
308/// Returns the JSON string suitable for sending as a WebSocket text frame.
309pub fn format_ws_message<StateT: AgentState>(event: &Event<StateT>) -> Result<String, ServerError> {
310    serde_json::to_string(event).map_err(|e| ServerError::Serialization(e.to_string()))
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use syncable_ag_ui_core::{
        MessageId, RunErrorEvent, TextMessageContentEvent, TextMessageStartEvent,
    };

    #[tokio::test]
    async fn test_channel_creation() {
        // A brand-new pair starts out connected.
        let (tx, _ws_handler) = channel::<JsonValue>(10);
        assert!(!tx.is_closed());
    }

    #[tokio::test]
    async fn test_channel_with_config() {
        let cfg = WsConfig::new()
            .ping_interval(Duration::from_secs(10))
            .disable_ping();

        let (tx, ws_handler) = channel_with_config::<JsonValue>(10, cfg);

        assert!(!tx.is_closed());
        // The handler must keep exactly the configuration it was given.
        assert!(!ws_handler.config.enable_ping);
        assert_eq!(ws_handler.config.ping_interval, Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_send_event() {
        let (tx, mut ws_handler) = channel::<JsonValue>(10);
        let sent: Event =
            Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random()));

        tx.send(sent.clone()).await.unwrap();

        // Read the queue directly rather than standing up a real socket.
        let got = ws_handler.receiver.recv().await.unwrap();
        assert_eq!(got.event_type(), sent.event_type());
    }

    #[tokio::test]
    async fn test_send_many_events() {
        let (tx, mut ws_handler) = channel::<JsonValue>(10);

        let batch: Vec<Event> = vec![
            Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())),
            Event::TextMessageContent(TextMessageContentEvent::new_unchecked(
                MessageId::random(),
                "Hello",
            )),
            Event::RunError(RunErrorEvent::new("test error")),
        ];

        tx.send_many(batch.clone()).await.unwrap();

        // Events must come out in the order they were queued.
        for want in &batch {
            let got = ws_handler.receiver.recv().await.unwrap();
            assert_eq!(got.event_type(), want.event_type());
        }
    }

    #[tokio::test]
    async fn test_channel_close_detection() {
        let (tx, ws_handler) = channel::<JsonValue>(10);

        // Dropping the handler closes the receiving side...
        drop(ws_handler);
        assert!(tx.is_closed());

        // ...so further sends are rejected.
        let event: Event = Event::RunError(RunErrorEvent::new("test"));
        assert!(tx.send(event).await.is_err());
    }

    #[tokio::test]
    async fn test_try_send() {
        let (tx, _ws_handler) = channel::<JsonValue>(2);
        let event: Event = Event::RunError(RunErrorEvent::new("test"));

        // Capacity is 2: exactly two non-blocking sends fit...
        for _ in 0..2 {
            assert!(tx.try_send(event.clone()).is_ok());
        }
        // ...and the next one is rejected because the buffer is full.
        assert!(tx.try_send(event).is_err());
    }

    #[test]
    fn test_format_ws_message() {
        let event: Event = Event::RunError(RunErrorEvent::new("test error"));
        let json = format_ws_message(&event).unwrap();

        for needle in ["\"type\":\"RUN_ERROR\"", "\"message\":\"test error\""] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn test_format_ws_message_complex() {
        let event: Event =
            Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random()));
        let json = format_ws_message(&event).unwrap();

        for needle in [
            "\"type\":\"TEXT_MESSAGE_START\"",
            "\"messageId\":",
            "\"role\":\"assistant\"",
        ] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn test_ws_config_default() {
        let cfg = WsConfig::default();
        assert!(cfg.enable_ping);
        assert_eq!(cfg.ping_interval, DEFAULT_PING_INTERVAL);
    }

    #[test]
    fn test_ws_config_builder() {
        let cfg = WsConfig::new()
            .ping_interval(Duration::from_secs(60))
            .disable_ping();

        assert!(!cfg.enable_ping);
        assert_eq!(cfg.ping_interval, Duration::from_secs(60));
    }

    #[test]
    fn test_send_error_display() {
        let err: SendError<i32> = SendError(42);
        assert_eq!(err.to_string(), "WebSocket channel closed");
    }
}