// turul_http_mcp_server/sse.rs

//! Server-Sent Events (SSE) support for MCP

use std::collections::HashMap;
use std::sync::Arc;

use futures::stream;
use serde_json::Value;
use tokio::sync::{broadcast, RwLock};
use tokio_stream::Stream;
use tracing::{debug, error};

11/// SSE event types
12#[derive(Debug, Clone)]
13pub enum SseEvent {
14    /// Connection established
15    Connected,
16    /// Data event with JSON payload
17    Data(Value),
18    /// Error event
19    Error(String),
20    /// Keep-alive ping
21    KeepAlive,
22}
23
24impl SseEvent {
25    /// Format as SSE message
26    pub fn format(&self) -> String {
27        match self {
28            SseEvent::Connected => {
29                "event: connected\ndata: {\"type\":\"connected\",\"message\":\"SSE connection established\"}\n\n".to_string()
30            }
31            SseEvent::Data(data) => {
32                format!("event: data\ndata: {}\n\n", serde_json::to_string(data).unwrap_or_else(|_| "{}".to_string()))
33            }
34            SseEvent::Error(msg) => {
35                format!("event: error\ndata: {{\"error\":\"{}\"}}\n\n", msg.replace('"', "\\\""))
36            }
37            SseEvent::KeepAlive => {
38                "event: ping\ndata: {\"type\":\"ping\"}\n\n".to_string()
39            }
40        }
41    }
42}
43
44/// SSE connection manager
45pub struct SseManager {
46    /// Broadcast channel for sending events to all connections
47    sender: broadcast::Sender<SseEvent>,
48    /// Connection registry
49    connections: Arc<RwLock<HashMap<String, SseConnection>>>,
50}
51
52/// Individual SSE connection
53#[derive(Debug)]
54pub struct SseConnection {
55    /// Connection ID
56    pub id: String,
57    /// Receiver for events
58    pub receiver: broadcast::Receiver<SseEvent>,
59}
60
61impl SseManager {
62    /// Create a new SSE manager
63    pub fn new() -> Self {
64        let (sender, _) = broadcast::channel(1000);
65        Self {
66            sender,
67            connections: Arc::new(RwLock::new(HashMap::new())),
68        }
69    }
70
71    /// Create a new SSE connection
72    pub async fn create_connection(&self, connection_id: String) -> SseConnection {
73        let receiver = self.sender.subscribe();
74        let connection = SseConnection {
75            id: connection_id.clone(),
76            receiver,
77        };
78
79        // Register the connection
80        {
81            let mut connections = self.connections.write().await;
82            connections.insert(connection_id, SseConnection {
83                id: connection.id.clone(),
84                receiver: self.sender.subscribe(),
85            });
86        }
87
88        debug!("SSE connection created: {}", connection.id);
89        
90        // Send connected event
91        let _ = self.sender.send(SseEvent::Connected);
92
93        connection
94    }
95
96    /// Remove a connection
97    pub async fn remove_connection(&self, connection_id: &str) {
98        let mut connections = self.connections.write().await;
99        connections.remove(connection_id);
100        debug!("SSE connection removed: {}", connection_id);
101    }
102
103    /// Broadcast an event to all connections
104    pub async fn broadcast(&self, event: SseEvent) {
105        if let Err(err) = self.sender.send(event) {
106            error!("Failed to broadcast SSE event: {}", err);
107        }
108    }
109
110    /// Send data to all connections
111    pub async fn send_data(&self, data: Value) {
112        self.broadcast(SseEvent::Data(data)).await;
113    }
114
115    /// Send error to all connections
116    pub async fn send_error(&self, message: String) {
117        self.broadcast(SseEvent::Error(message)).await;
118    }
119
120    /// Send keep-alive ping
121    pub async fn send_keep_alive(&self) {
122        self.broadcast(SseEvent::KeepAlive).await;
123    }
124
125    /// Get number of active connections
126    pub async fn connection_count(&self) -> usize {
127        let connections = self.connections.read().await;
128        connections.len()
129    }
130}
131
132impl Default for SseManager {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138impl SseConnection {
139    /// Convert to a stream of SSE-formatted strings
140    pub fn into_stream(self) -> impl Stream<Item = Result<String, broadcast::error::RecvError>> {
141        stream::unfold(self, |mut connection| async move {
142            match connection.receiver.recv().await {
143                Ok(event) => {
144                    let formatted = event.format();
145                    Some((Ok(formatted), connection))
146                }
147                Err(err) => Some((Err(err), connection)),
148            }
149        })
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_sse_event_format() {
        let cases = [
            (SseEvent::Connected, "event: connected"),
            (SseEvent::Data(json!({"message": "test"})), "event: data"),
            (SseEvent::Error("test error".to_string()), "event: error"),
            (SseEvent::KeepAlive, "event: ping"),
        ];

        for (event, expected_line) in cases {
            let formatted = event.format();
            assert!(formatted.contains(expected_line), "{formatted:?}");
            // Every SSE message must be terminated by a blank line.
            assert!(formatted.ends_with("\n\n"), "{formatted:?}");
        }
    }

    #[tokio::test]
    async fn test_sse_manager() {
        let manager = SseManager::new();
        assert_eq!(manager.connection_count().await, 0);

        let _conn = manager.create_connection("test-123".to_string()).await;
        assert_eq!(manager.connection_count().await, 1);

        manager.remove_connection("test-123").await;
        assert_eq!(manager.connection_count().await, 0);

        // Removing an unknown ID is a no-op.
        manager.remove_connection("does-not-exist").await;
        assert_eq!(manager.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_broadcast() {
        let manager = SseManager::new();
        let mut conn = manager.create_connection("test-456".to_string()).await;

        // The first event must be `Connected`; a failed recv fails the test
        // instead of silently skipping the assertion.
        let first = conn
            .receiver
            .recv()
            .await
            .expect("Connected event should be delivered");
        assert!(matches!(first, SseEvent::Connected));

        manager.send_data(json!({"test": "message"})).await;

        match conn
            .receiver
            .recv()
            .await
            .expect("data event should be delivered")
        {
            SseEvent::Data(data) => assert_eq!(data["test"], "message"),
            other => panic!("Expected data event, got {other:?}"),
        }
    }
}
208}