syncable_ag_ui_server/transport/
sse.rs1use std::convert::Infallible;
38use std::pin::Pin;
39use std::task::{Context, Poll};
40
41use syncable_ag_ui_core::{AgentState, Event, JsonValue};
42use axum::response::sse::{Event as AxumSseEvent, KeepAlive, Sse};
43use axum::response::IntoResponse;
44use futures::Stream;
45use tokio::sync::mpsc;
46use tokio_stream::wrappers::ReceiverStream;
47
48use crate::error::ServerError;
49
/// Error returned when an event could not be handed to the SSE channel.
///
/// The undelivered value is kept in the public `.0` field so the caller can
/// retry, log, or otherwise recover it.
#[derive(Debug, Clone)]
pub struct SendError<T>(pub T);

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The payload is deliberately not printed: `T` is not required to be `Display`.
        f.write_str("channel closed")
    }
}

/// Usable as a boxed `std::error::Error` whenever the payload is `Debug`.
impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}
61
62#[derive(Debug, Clone)]
67pub struct SseSender<StateT: AgentState = JsonValue> {
68 sender: mpsc::Sender<Event<StateT>>,
69}
70
71impl<StateT: AgentState> SseSender<StateT> {
72 pub async fn send(&self, event: Event<StateT>) -> Result<(), SendError<Event<StateT>>> {
76 self.sender.send(event).await.map_err(|e| SendError(e.0))
77 }
78
79 pub async fn send_many(
83 &self,
84 events: impl IntoIterator<Item = Event<StateT>>,
85 ) -> Result<(), SendError<Event<StateT>>> {
86 for event in events {
87 self.send(event).await?;
88 }
89 Ok(())
90 }
91
92 pub fn try_send(&self, event: Event<StateT>) -> Result<(), SendError<Event<StateT>>> {
96 self.sender.try_send(event).map_err(|e| SendError(e.into_inner()))
97 }
98
99 pub fn is_closed(&self) -> bool {
101 self.sender.is_closed()
102 }
103}
104
105pub struct SseHandler<StateT: AgentState = JsonValue> {
110 receiver: mpsc::Receiver<Event<StateT>>,
111}
112
113impl<StateT: AgentState> SseHandler<StateT> {
114 pub fn into_response(self) -> impl IntoResponse {
119 let stream = SseEventStream {
120 inner: ReceiverStream::new(self.receiver),
121 };
122
123 Sse::new(stream).keep_alive(KeepAlive::default())
124 }
125}
126
127struct SseEventStream<StateT: AgentState> {
129 inner: ReceiverStream<Event<StateT>>,
130}
131
132impl<StateT: AgentState> Stream for SseEventStream<StateT> {
133 type Item = Result<AxumSseEvent, Infallible>;
134
135 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136 match Pin::new(&mut self.inner).poll_next(cx) {
137 Poll::Ready(Some(event)) => {
138 let json = match serde_json::to_string(&event) {
140 Ok(json) => json,
141 Err(e) => {
142 eprintln!("SSE serialization error: {}", e);
144 format!(r#"{{"type":"RUN_ERROR","message":"Serialization error: {}"}}"#, e)
145 }
146 };
147
148 let sse_event = AxumSseEvent::default()
150 .event(event.event_type().as_str())
151 .data(json);
152
153 Poll::Ready(Some(Ok(sse_event)))
154 }
155 Poll::Ready(None) => Poll::Ready(None),
156 Poll::Pending => Poll::Pending,
157 }
158 }
159}
160
161pub fn channel<StateT: AgentState>(buffer: usize) -> (SseSender<StateT>, SseHandler<StateT>) {
180 let (tx, rx) = mpsc::channel(buffer);
181 (SseSender { sender: tx }, SseHandler { receiver: rx })
182}
183
184pub fn format_sse_event<StateT: AgentState>(event: &Event<StateT>) -> Result<String, ServerError> {
188 let json = serde_json::to_string(event)
189 .map_err(|e| ServerError::Serialization(e.to_string()))?;
190 Ok(format!("data: {}\n\n", json))
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use syncable_ag_ui_core::{
        MessageId, RunErrorEvent, TextMessageContentEvent, TextMessageStartEvent,
    };

    #[tokio::test]
    async fn test_channel_creation() {
        let (tx, _rx) = channel::<JsonValue>(10);
        // A freshly created channel still has its receiver alive.
        assert!(!tx.is_closed());
    }

    #[tokio::test]
    async fn test_send_event() {
        let (tx, mut rx) = channel::<JsonValue>(10);
        let sent: Event = Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random()));

        tx.send(sent.clone()).await.unwrap();

        let got = rx.receiver.recv().await.unwrap();
        assert_eq!(got.event_type(), sent.event_type());
    }

    #[tokio::test]
    async fn test_send_many_events() {
        let (tx, mut rx) = channel::<JsonValue>(10);
        let batch: Vec<Event> = vec![
            Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())),
            Event::TextMessageContent(TextMessageContentEvent::new_unchecked(
                MessageId::random(),
                "Hello",
            )),
            Event::RunError(RunErrorEvent::new("test error")),
        ];

        tx.send_many(batch.clone()).await.unwrap();

        // Events must come out in the same order they went in.
        for want in batch.iter() {
            let got = rx.receiver.recv().await.unwrap();
            assert_eq!(got.event_type(), want.event_type());
        }
    }

    #[tokio::test]
    async fn test_channel_close_detection() {
        let (tx, rx) = channel::<JsonValue>(10);

        // Dropping the receiving half closes the channel for every sender.
        drop(rx);
        assert!(tx.is_closed());

        let orphan: Event = Event::RunError(RunErrorEvent::new("test"));
        let outcome = tx.send(orphan).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_try_send() {
        let (tx, _rx) = channel::<JsonValue>(2);
        let ev: Event = Event::RunError(RunErrorEvent::new("test"));

        // Fill the two-slot buffer; the next non-blocking send must be rejected.
        for _ in 0..2 {
            assert!(tx.try_send(ev.clone()).is_ok());
        }
        assert!(tx.try_send(ev).is_err());
    }

    #[test]
    fn test_format_sse_event() {
        let ev: Event = Event::RunError(RunErrorEvent::new("test error"));
        let frame = format_sse_event(&ev).unwrap();

        assert!(frame.starts_with("data: "));
        assert!(frame.ends_with("\n\n"));
        for needle in ["\"type\":\"RUN_ERROR\"", "\"message\":\"test error\""] {
            assert!(frame.contains(needle));
        }
    }

    #[test]
    fn test_format_sse_event_with_complex_event() {
        let ev: Event = Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random()));
        let frame = format_sse_event(&ev).unwrap();

        for needle in [
            "\"type\":\"TEXT_MESSAGE_START\"",
            "\"messageId\":",
            "\"role\":\"assistant\"",
        ] {
            assert!(frame.contains(needle));
        }
    }
}