turul_http_mcp_server/
sse.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::{broadcast, RwLock};
6use tokio_stream::Stream;
7use futures::stream;
8use serde_json::Value;
9use tracing::{debug, error};
10
11#[derive(Debug, Clone)]
13pub enum SseEvent {
14 Connected,
16 Data(Value),
18 Error(String),
20 KeepAlive,
22}
23
24impl SseEvent {
25 pub fn format(&self) -> String {
27 match self {
28 SseEvent::Connected => {
29 "event: connected\ndata: {\"type\":\"connected\",\"message\":\"SSE connection established\"}\n\n".to_string()
30 }
31 SseEvent::Data(data) => {
32 format!("event: data\ndata: {}\n\n", serde_json::to_string(data).unwrap_or_else(|_| "{}".to_string()))
33 }
34 SseEvent::Error(msg) => {
35 format!("event: error\ndata: {{\"error\":\"{}\"}}\n\n", msg.replace('"', "\\\""))
36 }
37 SseEvent::KeepAlive => {
38 "event: ping\ndata: {\"type\":\"ping\"}\n\n".to_string()
39 }
40 }
41 }
42}
43
44pub struct SseManager {
46 sender: broadcast::Sender<SseEvent>,
48 connections: Arc<RwLock<HashMap<String, SseConnection>>>,
50}
51
52#[derive(Debug)]
54pub struct SseConnection {
55 pub id: String,
57 pub receiver: broadcast::Receiver<SseEvent>,
59}
60
61impl SseManager {
62 pub fn new() -> Self {
64 let (sender, _) = broadcast::channel(1000);
65 Self {
66 sender,
67 connections: Arc::new(RwLock::new(HashMap::new())),
68 }
69 }
70
71 pub async fn create_connection(&self, connection_id: String) -> SseConnection {
73 let receiver = self.sender.subscribe();
74 let connection = SseConnection {
75 id: connection_id.clone(),
76 receiver,
77 };
78
79 {
81 let mut connections = self.connections.write().await;
82 connections.insert(connection_id, SseConnection {
83 id: connection.id.clone(),
84 receiver: self.sender.subscribe(),
85 });
86 }
87
88 debug!("SSE connection created: {}", connection.id);
89
90 let _ = self.sender.send(SseEvent::Connected);
92
93 connection
94 }
95
96 pub async fn remove_connection(&self, connection_id: &str) {
98 let mut connections = self.connections.write().await;
99 connections.remove(connection_id);
100 debug!("SSE connection removed: {}", connection_id);
101 }
102
103 pub async fn broadcast(&self, event: SseEvent) {
105 if let Err(err) = self.sender.send(event) {
106 error!("Failed to broadcast SSE event: {}", err);
107 }
108 }
109
110 pub async fn send_data(&self, data: Value) {
112 self.broadcast(SseEvent::Data(data)).await;
113 }
114
115 pub async fn send_error(&self, message: String) {
117 self.broadcast(SseEvent::Error(message)).await;
118 }
119
120 pub async fn send_keep_alive(&self) {
122 self.broadcast(SseEvent::KeepAlive).await;
123 }
124
125 pub async fn connection_count(&self) -> usize {
127 let connections = self.connections.read().await;
128 connections.len()
129 }
130}
131
132impl Default for SseManager {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138impl SseConnection {
139 pub fn into_stream(self) -> impl Stream<Item = Result<String, broadcast::error::RecvError>> {
141 stream::unfold(self, |mut connection| async move {
142 match connection.receiver.recv().await {
143 Ok(event) => {
144 let formatted = event.format();
145 Some((Ok(formatted), connection))
146 }
147 Err(err) => Some((Err(err), connection)),
148 }
149 })
150 }
151}
152
/// Unit tests for SSE event formatting and the broadcast manager.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_sse_event_format() {
        let connected = SseEvent::Connected;
        assert!(connected.format().contains("event: connected"));

        let data = SseEvent::Data(json!({"message": "test"}));
        assert!(data.format().contains("event: data"));

        let error = SseEvent::Error("test error".to_string());
        assert!(error.format().contains("event: error"));

        let ping = SseEvent::KeepAlive;
        assert!(ping.format().contains("event: ping"));
    }

    #[tokio::test]
    async fn test_sse_manager() {
        let manager = SseManager::new();
        assert_eq!(manager.connection_count().await, 0);

        let _conn = manager.create_connection("test-123".to_string()).await;
        assert_eq!(manager.connection_count().await, 1);

        manager.remove_connection("test-123").await;
        assert_eq!(manager.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_broadcast() {
        let manager = SseManager::new();
        let mut conn = manager.create_connection("test-456".to_string()).await;

        // Use `expect` so a receive error fails the test instead of silently
        // skipping the assertions.
        let event = conn
            .receiver
            .recv()
            .await
            .expect("expected the Connected event after subscribing");
        assert!(matches!(event, SseEvent::Connected));

        manager.send_data(json!({"test": "message"})).await;

        match conn
            .receiver
            .recv()
            .await
            .expect("expected the broadcast data event")
        {
            SseEvent::Data(data) => assert_eq!(data["test"], "message"),
            other => panic!("Expected data event, got {:?}", other),
        }
    }
}