// turul_http_mcp_server/sse.rs
//
// Server-Sent Events (SSE) support: event formatting, a broadcast-based
// connection manager, and conversion of a connection into a stream of frames.

use std::collections::HashMap;
use std::sync::Arc;

use futures::stream;
use serde_json::Value;
use tokio::sync::{RwLock, broadcast};
use tokio_stream::Stream;
use tracing::{debug, error};

11#[derive(Debug, Clone)]
13pub enum SseEvent {
14 Connected,
16 Data(Value),
18 Error(String),
20 KeepAlive,
22}
23
24impl SseEvent {
25 pub fn format(&self) -> String {
27 match self {
28 SseEvent::Connected => {
29 "event: message\ndata: {\"type\":\"connected\",\"message\":\"SSE connection established\"}\n\n".to_string()
31 }
32 SseEvent::Data(data) => {
33 format!(
35 "event: message\ndata: {}\n\n",
36 serde_json::to_string(data).unwrap_or_else(|_| "{}".to_string())
37 )
38 }
39 SseEvent::Error(msg) => {
40 format!(
42 "event: message\ndata: {{\"error\":\"{}\"}}\n\n",
43 msg.replace('"', "\\\"")
44 )
45 }
46 SseEvent::KeepAlive => {
47 ": keepalive\n\n".to_string()
49 }
50 }
51 }
52}
53
54pub struct SseManager {
56 sender: broadcast::Sender<SseEvent>,
58 connections: Arc<RwLock<HashMap<String, SseConnection>>>,
60}
61
62#[derive(Debug)]
64pub struct SseConnection {
65 pub id: String,
67 pub receiver: broadcast::Receiver<SseEvent>,
69}
70
71impl SseManager {
72 pub fn new() -> Self {
74 let (sender, _) = broadcast::channel(1000);
75 Self {
76 sender,
77 connections: Arc::new(RwLock::new(HashMap::new())),
78 }
79 }
80
81 pub async fn create_connection(&self, connection_id: String) -> SseConnection {
83 let receiver = self.sender.subscribe();
84 let connection = SseConnection {
85 id: connection_id.clone(),
86 receiver,
87 };
88
89 {
91 let mut connections = self.connections.write().await;
92 connections.insert(
93 connection_id,
94 SseConnection {
95 id: connection.id.clone(),
96 receiver: self.sender.subscribe(),
97 },
98 );
99 }
100
101 debug!("SSE connection created: {}", connection.id);
102
103 let _ = self.sender.send(SseEvent::Connected);
105
106 connection
107 }
108
109 pub async fn remove_connection(&self, connection_id: &str) {
111 let mut connections = self.connections.write().await;
112 connections.remove(connection_id);
113 debug!("SSE connection removed: {}", connection_id);
114 }
115
116 pub async fn broadcast(&self, event: SseEvent) {
118 if let Err(err) = self.sender.send(event) {
119 error!("Failed to broadcast SSE event: {}", err);
120 }
121 }
122
123 pub async fn send_data(&self, data: Value) {
125 self.broadcast(SseEvent::Data(data)).await;
126 }
127
128 pub async fn send_error(&self, message: String) {
130 self.broadcast(SseEvent::Error(message)).await;
131 }
132
133 pub async fn send_keep_alive(&self) {
135 self.broadcast(SseEvent::KeepAlive).await;
136 }
137
138 pub async fn connection_count(&self) -> usize {
140 let connections = self.connections.read().await;
141 connections.len()
142 }
143}
144
145impl Default for SseManager {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
151impl SseConnection {
152 pub fn into_stream(self) -> impl Stream<Item = Result<String, broadcast::error::RecvError>> {
154 stream::unfold(self, |mut connection| async move {
155 match connection.receiver.recv().await {
156 Ok(event) => {
157 let formatted = event.format();
158 Some((Ok(formatted), connection))
159 }
160 Err(err) => Some((Err(err), connection)),
161 }
162 })
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use serde_json::json;

    #[test]
    fn test_sse_event_format() {
        assert_eq!(
            SseEvent::Connected.format(),
            "event: message\ndata: {\"type\":\"connected\",\"message\":\"SSE connection established\"}\n\n"
        );

        assert_eq!(
            SseEvent::Data(json!({"message": "test"})).format(),
            "event: message\ndata: {\"message\":\"test\"}\n\n"
        );

        assert_eq!(
            SseEvent::Error("test error".to_string()).format(),
            "event: message\ndata: {\"error\":\"test error\"}\n\n"
        );
        assert_eq!(
            SseEvent::Error("say \"hi\"".to_string()).format(),
            "event: message\ndata: {\"error\":\"say \\\"hi\\\"\"}\n\n"
        );

        // Keep-alives are SSE comments: no event name, leading colon.
        let ping = SseEvent::KeepAlive.format();
        assert!(!ping.contains("event:"));
        assert_eq!(ping, ": keepalive\n\n");
    }

    #[tokio::test]
    async fn test_sse_manager() {
        let manager = SseManager::new();
        assert_eq!(manager.connection_count().await, 0);

        let _conn = manager.create_connection("test-123".to_string()).await;
        assert_eq!(manager.connection_count().await, 1);

        manager.remove_connection("test-123").await;
        assert_eq!(manager.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_broadcast() {
        let manager = SseManager::new();
        let mut conn = manager.create_connection("test-456".to_string()).await;

        // `expect` instead of `if let Ok`, so a receive error fails the test
        // instead of silently skipping the assertion.
        let first = conn
            .receiver
            .recv()
            .await
            .expect("connected event should be delivered");
        assert!(matches!(first, SseEvent::Connected));

        manager.send_data(json!({"test": "message"})).await;

        match conn
            .receiver
            .recv()
            .await
            .expect("data event should be delivered")
        {
            SseEvent::Data(data) => assert_eq!(data["test"], "message"),
            other => panic!("Expected data event, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_into_stream_formats_events() {
        let manager = SseManager::new();
        let conn = manager.create_connection("test-789".to_string()).await;

        // The unfold future is not `Unpin`; pin it so `StreamExt::next` can be used.
        let mut frames = Box::pin(conn.into_stream());
        let first = frames
            .next()
            .await
            .expect("stream should yield a frame")
            .expect("frame should not be a receive error");
        assert_eq!(first, SseEvent::Connected.format());
    }
}