syncable_ag_ui_server/transport/
ws.rs1use std::time::Duration;
47
48use syncable_ag_ui_core::{AgentState, Event, JsonValue};
49use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
50use axum::response::IntoResponse;
51use futures::{SinkExt, StreamExt};
52use tokio::sync::mpsc;
53use tokio::time::interval;
54
55use crate::error::ServerError;
56
/// Default period between the Ping frames a `WsHandler` sends while pings are
/// enabled (30 seconds).
pub const DEFAULT_PING_INTERVAL: Duration = Duration::from_millis(30_000);
59
/// Error returned by `WsSender` when an event could not be queued.
///
/// The rejected value is handed back in the public tuple field so the caller
/// can retry it or inspect it.
#[derive(Debug, Clone)]
pub struct SendError<T>(pub T);

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("WebSocket channel closed")
    }
}

impl<T> std::error::Error for SendError<T> where T: std::fmt::Debug {}
71
72#[derive(Debug, Clone)]
74pub struct WsConfig {
75 pub ping_interval: Duration,
77 pub enable_ping: bool,
79}
80
81impl Default for WsConfig {
82 fn default() -> Self {
83 Self {
84 ping_interval: DEFAULT_PING_INTERVAL,
85 enable_ping: true,
86 }
87 }
88}
89
90impl WsConfig {
91 pub fn new() -> Self {
93 Self::default()
94 }
95
96 pub fn ping_interval(mut self, interval: Duration) -> Self {
98 self.ping_interval = interval;
99 self
100 }
101
102 pub fn disable_ping(mut self) -> Self {
104 self.enable_ping = false;
105 self
106 }
107}
108
109#[derive(Debug, Clone)]
114pub struct WsSender<StateT: AgentState = JsonValue> {
115 sender: mpsc::Sender<Event<StateT>>,
116}
117
118impl<StateT: AgentState> WsSender<StateT> {
119 pub async fn send(&self, event: Event<StateT>) -> Result<(), SendError<Event<StateT>>> {
123 self.sender.send(event).await.map_err(|e| SendError(e.0))
124 }
125
126 pub async fn send_many(
130 &self,
131 events: impl IntoIterator<Item = Event<StateT>>,
132 ) -> Result<(), SendError<Event<StateT>>> {
133 for event in events {
134 self.send(event).await?;
135 }
136 Ok(())
137 }
138
139 pub fn try_send(&self, event: Event<StateT>) -> Result<(), SendError<Event<StateT>>> {
143 self.sender
144 .try_send(event)
145 .map_err(|e| SendError(e.into_inner()))
146 }
147
148 pub fn is_closed(&self) -> bool {
150 self.sender.is_closed()
151 }
152}
153
154pub struct WsHandler<StateT: AgentState = JsonValue> {
158 receiver: mpsc::Receiver<Event<StateT>>,
159 config: WsConfig,
160}
161
162impl<StateT: AgentState> WsHandler<StateT> {
163 pub fn into_response(self, upgrade: WebSocketUpgrade) -> impl IntoResponse {
168 upgrade.on_upgrade(move |socket| self.handle_socket(socket))
169 }
170
171 async fn handle_socket(self, socket: WebSocket) {
173 let (mut ws_sender, mut ws_receiver) = socket.split();
174 let mut event_receiver = self.receiver;
175
176 let mut ping_interval = if self.config.enable_ping {
178 Some(interval(self.config.ping_interval))
179 } else {
180 None
181 };
182
183 loop {
184 tokio::select! {
185 event = event_receiver.recv() => {
187 match event {
188 Some(event) => {
189 let json = match serde_json::to_string(&event) {
191 Ok(json) => json,
192 Err(e) => {
193 eprintln!("WebSocket serialization error: {}", e);
194 continue;
195 }
196 };
197
198 if ws_sender.send(Message::Text(json.into())).await.is_err() {
200 break;
202 }
203 }
204 None => {
205 let _ = ws_sender.send(Message::Close(None)).await;
207 break;
208 }
209 }
210 }
211
212 _ = async {
214 if let Some(ref mut interval) = ping_interval {
215 interval.tick().await;
216 } else {
217 std::future::pending::<()>().await;
219 }
220 } => {
221 if ws_sender.send(Message::Ping(vec![].into())).await.is_err() {
222 break;
223 }
224 }
225
226 msg = ws_receiver.next() => {
228 match msg {
229 Some(Ok(Message::Pong(_))) => {
230 }
232 Some(Ok(Message::Close(_))) | None => {
233 break;
235 }
236 Some(Ok(_)) => {
237 }
240 Some(Err(_)) => {
241 break;
243 }
244 }
245 }
246 }
247 }
248 }
249}
250
251pub fn channel<StateT: AgentState>(buffer: usize) -> (WsSender<StateT>, WsHandler<StateT>) {
270 channel_with_config(buffer, WsConfig::default())
271}
272
273pub fn channel_with_config<StateT: AgentState>(
293 buffer: usize,
294 config: WsConfig,
295) -> (WsSender<StateT>, WsHandler<StateT>) {
296 let (tx, rx) = mpsc::channel(buffer);
297 (
298 WsSender { sender: tx },
299 WsHandler {
300 receiver: rx,
301 config,
302 },
303 )
304}
305
306pub fn format_ws_message<StateT: AgentState>(event: &Event<StateT>) -> Result<String, ServerError> {
310 serde_json::to_string(event).map_err(|e| ServerError::Serialization(e.to_string()))
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use syncable_ag_ui_core::{MessageId, RunErrorEvent, TextMessageContentEvent, TextMessageStartEvent};

    #[tokio::test]
    async fn test_channel_creation() {
        let (tx, _unused_handler) = channel::<JsonValue>(10);
        assert!(!tx.is_closed(), "a freshly created channel must be open");
    }

    #[tokio::test]
    async fn test_channel_with_config() {
        let custom = WsConfig::new()
            .ping_interval(Duration::from_secs(10))
            .disable_ping();
        let (tx, handler) = channel_with_config::<JsonValue>(10, custom);

        assert!(!tx.is_closed());
        assert!(!handler.config.enable_ping);
        assert_eq!(handler.config.ping_interval, Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_send_event() {
        let (tx, mut handler) = channel::<JsonValue>(10);
        let original: Event =
            Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random()));

        tx.send(original.clone()).await.unwrap();

        let delivered = handler
            .receiver
            .recv()
            .await
            .expect("queued event should be delivered");
        assert_eq!(delivered.event_type(), original.event_type());
    }

    #[tokio::test]
    async fn test_send_many_events() {
        let (tx, mut handler) = channel::<JsonValue>(10);
        let batch: Vec<Event> = vec![
            Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())),
            Event::TextMessageContent(TextMessageContentEvent::new_unchecked(
                MessageId::random(),
                "Hello",
            )),
            Event::RunError(RunErrorEvent::new("test error")),
        ];

        tx.send_many(batch.clone()).await.unwrap();

        // Events must arrive in the order they were queued.
        for want in batch.iter() {
            let got = handler.receiver.recv().await.unwrap();
            assert_eq!(got.event_type(), want.event_type());
        }
    }

    #[tokio::test]
    async fn test_channel_close_detection() {
        let (tx, handler) = channel::<JsonValue>(10);
        drop(handler);

        assert!(tx.is_closed(), "dropping the handler must close the channel");

        let orphan: Event = Event::RunError(RunErrorEvent::new("test"));
        let outcome = tx.send(orphan).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_try_send() {
        let (tx, _handler) = channel::<JsonValue>(2);
        let event: Event = Event::RunError(RunErrorEvent::new("test"));

        // The two buffer slots accept events; the third attempt finds it full.
        for _ in 0..2 {
            assert!(tx.try_send(event.clone()).is_ok());
        }
        assert!(tx.try_send(event).is_err());
    }

    #[test]
    fn test_format_ws_message() {
        let event: Event = Event::RunError(RunErrorEvent::new("test error"));
        let json = format_ws_message(&event).unwrap();

        for fragment in ["\"type\":\"RUN_ERROR\"", "\"message\":\"test error\""] {
            assert!(json.contains(fragment), "missing {} in {}", fragment, json);
        }
    }

    #[test]
    fn test_format_ws_message_complex() {
        let event: Event =
            Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random()));
        let json = format_ws_message(&event).unwrap();

        for fragment in [
            "\"type\":\"TEXT_MESSAGE_START\"",
            "\"messageId\":",
            "\"role\":\"assistant\"",
        ] {
            assert!(json.contains(fragment), "missing {} in {}", fragment, json);
        }
    }

    #[test]
    fn test_ws_config_default() {
        let defaults = WsConfig::default();
        assert!(defaults.enable_ping);
        assert_eq!(defaults.ping_interval, DEFAULT_PING_INTERVAL);
    }

    #[test]
    fn test_ws_config_builder() {
        let built = WsConfig::new()
            .ping_interval(Duration::from_secs(60))
            .disable_ping();

        assert!(!built.enable_ping);
        assert_eq!(built.ping_interval, Duration::from_secs(60));
    }

    #[test]
    fn test_send_error_display() {
        let err: SendError<i32> = SendError(42);
        assert_eq!(err.to_string(), "WebSocket channel closed");
    }
}