Skip to main content

syncable_ag_ui_server/transport/
sse.rs

1//! Server-Sent Events (SSE) Transport
2//!
3//! This module provides SSE transport for streaming AG-UI events to frontend clients.
4//! It integrates with axum to provide HTTP SSE endpoints.
5//!
6//! # Architecture
7//!
8//! The SSE transport uses a channel-based design:
9//! - [`SseSender`] - Used by agent code to send events into the stream
10//! - [`SseHandler`] - Converted to an axum SSE response for the HTTP endpoint
11//!
12//! # Example
13//!
14//! ```rust,ignore
15//! use syncable_ag_ui_server::transport::sse;
16//! use syncable_ag_ui_core::{Event, TextMessageStartEvent, MessageId};
17//!
18//! // Create a channel pair
19//! let (sender, handler) = sse::channel::<serde_json::Value>(32);
20//!
21//! // In your axum handler, return the SSE response
22//! async fn events_endpoint() -> impl IntoResponse {
23//!     let (sender, handler) = sse::channel::<serde_json::Value>(32);
24//!
25//!     // Spawn task to send events
26//!     tokio::spawn(async move {
27//!         let event = Event::TextMessageStart(
28//!             TextMessageStartEvent::new(MessageId::random())
29//!         );
30//!         sender.send(event).await.ok();
31//!     });
32//!
33//!     handler.into_response()
34//! }
35//! ```
36
37use std::convert::Infallible;
38use std::pin::Pin;
39use std::task::{Context, Poll};
40
41use syncable_ag_ui_core::{AgentState, Event, JsonValue};
42use axum::response::sse::{Event as AxumSseEvent, KeepAlive, Sse};
43use axum::response::IntoResponse;
44use futures::Stream;
45use tokio::sync::mpsc;
46use tokio_stream::wrappers::ReceiverStream;
47
48use crate::error::ServerError;
49
/// Error returned when an event could not be handed to the SSE channel.
///
/// The wrapped value is the item that was not delivered, which lets the
/// caller recover it through the public tuple field.
#[derive(Debug, Clone)]
pub struct SendError<T>(pub T);

impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}

impl<T> std::fmt::Display for SendError<T> {
    /// Writes the fixed text `channel closed`; the payload is never printed.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str("channel closed")
    }
}
61
62/// Sender side of an SSE channel.
63///
64/// Use this to send AG-UI events that will be streamed to connected clients.
65/// Events are serialized to JSON and formatted as SSE data frames.
66#[derive(Debug, Clone)]
67pub struct SseSender<StateT: AgentState = JsonValue> {
68    sender: mpsc::Sender<Event<StateT>>,
69}
70
71impl<StateT: AgentState> SseSender<StateT> {
72    /// Sends an event to the SSE stream.
73    ///
74    /// Returns an error if the receiver has been dropped (client disconnected).
75    pub async fn send(&self, event: Event<StateT>) -> Result<(), SendError<Event<StateT>>> {
76        self.sender.send(event).await.map_err(|e| SendError(e.0))
77    }
78
79    /// Sends multiple events to the SSE stream.
80    ///
81    /// Stops and returns an error on the first failed send.
82    pub async fn send_many(
83        &self,
84        events: impl IntoIterator<Item = Event<StateT>>,
85    ) -> Result<(), SendError<Event<StateT>>> {
86        for event in events {
87            self.send(event).await?;
88        }
89        Ok(())
90    }
91
92    /// Tries to send an event without waiting.
93    ///
94    /// Returns an error if the channel is full or closed.
95    pub fn try_send(&self, event: Event<StateT>) -> Result<(), SendError<Event<StateT>>> {
96        self.sender.try_send(event).map_err(|e| SendError(e.into_inner()))
97    }
98
99    /// Checks if the receiver is still connected.
100    pub fn is_closed(&self) -> bool {
101        self.sender.is_closed()
102    }
103}
104
105/// Handler side of an SSE channel.
106///
107/// This is converted to an axum SSE response that streams events to the client.
108/// Each event is serialized to JSON and sent as an SSE data frame.
109pub struct SseHandler<StateT: AgentState = JsonValue> {
110    receiver: mpsc::Receiver<Event<StateT>>,
111}
112
113impl<StateT: AgentState> SseHandler<StateT> {
114    /// Converts this handler into an axum SSE response.
115    ///
116    /// The response will stream events as they are sent through the corresponding
117    /// [`SseSender`]. The stream ends when the sender is dropped.
118    pub fn into_response(self) -> impl IntoResponse {
119        let stream = SseEventStream {
120            inner: ReceiverStream::new(self.receiver),
121        };
122
123        Sse::new(stream).keep_alive(KeepAlive::default())
124    }
125}
126
127/// Internal stream wrapper that converts Events to axum SSE events.
128struct SseEventStream<StateT: AgentState> {
129    inner: ReceiverStream<Event<StateT>>,
130}
131
132impl<StateT: AgentState> Stream for SseEventStream<StateT> {
133    type Item = Result<AxumSseEvent, Infallible>;
134
135    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136        match Pin::new(&mut self.inner).poll_next(cx) {
137            Poll::Ready(Some(event)) => {
138                // Serialize event to JSON
139                let json = match serde_json::to_string(&event) {
140                    Ok(json) => json,
141                    Err(e) => {
142                        // Log error and send error event
143                        eprintln!("SSE serialization error: {}", e);
144                        format!(r#"{{"type":"RUN_ERROR","message":"Serialization error: {}"}}"#, e)
145                    }
146                };
147
148                // Create SSE event with the event type as the SSE event name
149                let sse_event = AxumSseEvent::default()
150                    .event(event.event_type().as_str())
151                    .data(json);
152
153                Poll::Ready(Some(Ok(sse_event)))
154            }
155            Poll::Ready(None) => Poll::Ready(None),
156            Poll::Pending => Poll::Pending,
157        }
158    }
159}
160
161/// Creates a new SSE channel pair.
162///
163/// The `buffer` parameter controls how many events can be queued before
164/// sends will block (or fail for `try_send`).
165///
166/// # Arguments
167///
168/// * `buffer` - The capacity of the internal channel buffer
169///
170/// # Returns
171///
172/// A tuple of (`SseSender`, `SseHandler`) that are connected.
173///
174/// # Example
175///
176/// ```rust,ignore
177/// let (sender, handler) = sse::channel::<serde_json::Value>(32);
178/// ```
179pub fn channel<StateT: AgentState>(buffer: usize) -> (SseSender<StateT>, SseHandler<StateT>) {
180    let (tx, rx) = mpsc::channel(buffer);
181    (SseSender { sender: tx }, SseHandler { receiver: rx })
182}
183
184/// Serializes an event to SSE format.
185///
186/// Returns the event formatted as `data: {json}\n\n`.
187pub fn format_sse_event<StateT: AgentState>(event: &Event<StateT>) -> Result<String, ServerError> {
188    let json = serde_json::to_string(event)
189        .map_err(|e| ServerError::Serialization(e.to_string()))?;
190    Ok(format!("data: {}\n\n", json))
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use syncable_ag_ui_core::{
        MessageId, RunErrorEvent, TextMessageContentEvent, TextMessageStartEvent,
    };

    /// Builds a `RunError` event carrying `message`.
    fn run_error(message: &str) -> Event {
        Event::RunError(RunErrorEvent::new(message))
    }

    /// Builds a `TextMessageStart` event with a random message id.
    fn text_start() -> Event {
        Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random()))
    }

    /// A freshly created channel is reported as open.
    #[tokio::test]
    async fn test_channel_creation() {
        let (tx, _rx) = channel::<JsonValue>(10);
        assert!(!tx.is_closed());
    }

    /// One event sent through the sender arrives at the receiver.
    #[tokio::test]
    async fn test_send_event() {
        let (tx, mut rx_side) = channel::<JsonValue>(10);
        let outgoing: Event = text_start();

        tx.send(outgoing.clone()).await.unwrap();

        // Read straight from the handler's receiver; no HTTP layer involved.
        let incoming = rx_side.receiver.recv().await.unwrap();
        assert_eq!(incoming.event_type(), outgoing.event_type());
    }

    /// `send_many` delivers every event, in order.
    #[tokio::test]
    async fn test_send_many_events() {
        let (tx, mut rx_side) = channel::<JsonValue>(10);

        let batch: Vec<Event> = vec![
            text_start(),
            Event::TextMessageContent(TextMessageContentEvent::new_unchecked(
                MessageId::random(),
                "Hello",
            )),
            run_error("test error"),
        ];

        tx.send_many(batch.clone()).await.unwrap();

        for expected in batch.iter() {
            let incoming = rx_side.receiver.recv().await.unwrap();
            assert_eq!(incoming.event_type(), expected.event_type());
        }
    }

    /// Dropping the handler closes the channel from the sender's viewpoint.
    #[tokio::test]
    async fn test_channel_close_detection() {
        let (tx, rx_side) = channel::<JsonValue>(10);
        drop(rx_side);

        assert!(tx.is_closed());

        let outgoing: Event = run_error("test");
        let outcome = tx.send(outgoing).await;
        assert!(outcome.is_err());
    }

    /// `try_send` succeeds until the buffer is full, then fails.
    #[tokio::test]
    async fn test_try_send() {
        let (tx, _rx) = channel::<JsonValue>(2);
        let outgoing: Event = run_error("test");

        // Capacity is 2, so exactly two sends fit.
        assert!(tx.try_send(outgoing.clone()).is_ok());
        assert!(tx.try_send(outgoing.clone()).is_ok());

        // Nothing has been received, so the third must be rejected.
        assert!(tx.try_send(outgoing).is_err());
    }

    /// The bare formatter produces a `data:` frame containing the event JSON.
    #[test]
    fn test_format_sse_event() {
        let event: Event = run_error("test error");
        let frame = format_sse_event(&event).unwrap();

        assert!(frame.starts_with("data: "));
        assert!(frame.ends_with("\n\n"));
        assert!(frame.contains("\"type\":\"RUN_ERROR\""));
        assert!(frame.contains("\"message\":\"test error\""));
    }

    /// A text-message start event serializes its type, id and role.
    #[test]
    fn test_format_sse_event_with_complex_event() {
        let event: Event = text_start();
        let frame = format_sse_event(&event).unwrap();

        assert!(frame.contains("\"type\":\"TEXT_MESSAGE_START\""));
        assert!(frame.contains("\"messageId\":"));
        assert!(frame.contains("\"role\":\"assistant\""));
    }
}