turul_http_mcp_server/
sse.rs

1//! Server-Sent Events (SSE) support for MCP
2
3use futures::stream;
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::{RwLock, broadcast};
8use tokio_stream::Stream;
9use tracing::{debug, error};
10
11/// SSE event types
12#[derive(Debug, Clone)]
13pub enum SseEvent {
14    /// Connection established
15    Connected,
16    /// Data event with JSON payload
17    Data(Value),
18    /// Error event
19    Error(String),
20    /// Keep-alive ping
21    KeepAlive,
22}
23
24impl SseEvent {
25    /// Format as SSE message
26    pub fn format(&self) -> String {
27        match self {
28            SseEvent::Connected => {
29                // Use "message" event type for MCP Inspector compatibility
30                "event: message\ndata: {\"type\":\"connected\",\"message\":\"SSE connection established\"}\n\n".to_string()
31            }
32            SseEvent::Data(data) => {
33                // Use "message" event type for MCP Inspector compatibility
34                format!(
35                    "event: message\ndata: {}\n\n",
36                    serde_json::to_string(data).unwrap_or_else(|_| "{}".to_string())
37                )
38            }
39            SseEvent::Error(msg) => {
40                // Use "message" event type for MCP Inspector compatibility
41                format!(
42                    "event: message\ndata: {{\"error\":\"{}\"}}\n\n",
43                    msg.replace('"', "\\\"")
44                )
45            }
46            SseEvent::KeepAlive => {
47                // Keepalive events omit the event line (per SSE spec for compatibility)
48                ": keepalive\n\n".to_string()
49            }
50        }
51    }
52}
53
54/// SSE connection manager
55pub struct SseManager {
56    /// Broadcast channel for sending events to all connections
57    sender: broadcast::Sender<SseEvent>,
58    /// Connection registry
59    connections: Arc<RwLock<HashMap<String, SseConnection>>>,
60}
61
62/// Individual SSE connection
63#[derive(Debug)]
64pub struct SseConnection {
65    /// Connection ID
66    pub id: String,
67    /// Receiver for events
68    pub receiver: broadcast::Receiver<SseEvent>,
69}
70
71impl SseManager {
72    /// Create a new SSE manager
73    pub fn new() -> Self {
74        let (sender, _) = broadcast::channel(1000);
75        Self {
76            sender,
77            connections: Arc::new(RwLock::new(HashMap::new())),
78        }
79    }
80
81    /// Create a new SSE connection
82    pub async fn create_connection(&self, connection_id: String) -> SseConnection {
83        let receiver = self.sender.subscribe();
84        let connection = SseConnection {
85            id: connection_id.clone(),
86            receiver,
87        };
88
89        // Register the connection
90        {
91            let mut connections = self.connections.write().await;
92            connections.insert(
93                connection_id,
94                SseConnection {
95                    id: connection.id.clone(),
96                    receiver: self.sender.subscribe(),
97                },
98            );
99        }
100
101        debug!("SSE connection created: {}", connection.id);
102
103        // Send connected event
104        let _ = self.sender.send(SseEvent::Connected);
105
106        connection
107    }
108
109    /// Remove a connection
110    pub async fn remove_connection(&self, connection_id: &str) {
111        let mut connections = self.connections.write().await;
112        connections.remove(connection_id);
113        debug!("SSE connection removed: {}", connection_id);
114    }
115
116    /// Broadcast an event to all connections
117    pub async fn broadcast(&self, event: SseEvent) {
118        if let Err(err) = self.sender.send(event) {
119            error!("Failed to broadcast SSE event: {}", err);
120        }
121    }
122
123    /// Send data to all connections
124    pub async fn send_data(&self, data: Value) {
125        self.broadcast(SseEvent::Data(data)).await;
126    }
127
128    /// Send error to all connections
129    pub async fn send_error(&self, message: String) {
130        self.broadcast(SseEvent::Error(message)).await;
131    }
132
133    /// Send keep-alive ping
134    pub async fn send_keep_alive(&self) {
135        self.broadcast(SseEvent::KeepAlive).await;
136    }
137
138    /// Get number of active connections
139    pub async fn connection_count(&self) -> usize {
140        let connections = self.connections.read().await;
141        connections.len()
142    }
143}
144
145impl Default for SseManager {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
151impl SseConnection {
152    /// Convert to a stream of SSE-formatted strings
153    pub fn into_stream(self) -> impl Stream<Item = Result<String, broadcast::error::RecvError>> {
154        stream::unfold(self, |mut connection| async move {
155            match connection.receiver.recv().await {
156                Ok(event) => {
157                    let formatted = event.format();
158                    Some((Ok(formatted), connection))
159                }
160                Err(err) => Some((Err(err), connection)),
161            }
162        })
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Checks the event line and frame terminator of every event kind.
    #[test]
    fn test_sse_event_format() {
        // All events except keepalive use "event: message" for MCP Inspector compatibility
        let connected = SseEvent::Connected.format();
        assert!(connected.contains("event: message"));
        assert!(connected.ends_with("\n\n"));

        let data = SseEvent::Data(json!({"message": "test"})).format();
        assert!(data.contains("event: message"));
        assert!(data.ends_with("\n\n"));

        let error = SseEvent::Error("test error".to_string()).format();
        assert!(error.contains("event: message"));
        assert!(error.ends_with("\n\n"));

        // Keepalive uses SSE comment syntax (no event line)
        let ping = SseEvent::KeepAlive.format();
        assert!(!ping.contains("event:"));
        assert!(ping.starts_with(":"));
        assert!(ping.ends_with("\n\n"));
    }

    /// Checks that the registry count follows create/remove calls.
    #[tokio::test]
    async fn test_sse_manager() {
        let manager = SseManager::new();
        assert_eq!(manager.connection_count().await, 0);

        let _conn = manager.create_connection("test-123".to_string()).await;
        assert_eq!(manager.connection_count().await, 1);

        manager.remove_connection("test-123").await;
        assert_eq!(manager.connection_count().await, 0);
    }

    /// Checks that a connection receives `Connected` first and then broadcast data.
    ///
    /// Both receives must succeed; a failed receive fails the test instead of
    /// silently skipping the assertions.
    #[tokio::test]
    async fn test_broadcast() {
        let manager = SseManager::new();
        let mut conn = manager.create_connection("test-456".to_string()).await;

        // First event should be Connected
        let first = conn
            .receiver
            .recv()
            .await
            .expect("connected event should be delivered");
        assert!(matches!(first, SseEvent::Connected));

        // Send a test event
        manager.send_data(json!({"test": "message"})).await;

        // The connection should receive the data event
        let second = conn
            .receiver
            .recv()
            .await
            .expect("data event should be delivered");
        match second {
            SseEvent::Data(data) => assert_eq!(data["test"], "message"),
            other => panic!("Expected data event, got {:?}", other),
        }
    }
}